diff --git a/.gitignore b/.gitignore index dbd34eb..f79f03f 100644 --- a/.gitignore +++ b/.gitignore @@ -270,4 +270,5 @@ benchmarks/*/data/ benchmarks/*/.env* benchmarks/*/logs/ benchmarks/*/results/ -benchmarks/*/output/ \ No newline at end of file +benchmarks/*/output/ +benchmarks/*/.work/ diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md new file mode 100644 index 0000000..9bb0134 --- /dev/null +++ b/benchmarks/financebench/README.md @@ -0,0 +1,268 @@ +# FinanceBench Benchmark + +FinanceBench evaluation pipeline for **Sirchmunk AgenticSearch**. + +## Overview + +[FinanceBench](https://arxiv.org/abs/2311.11944) is an open-book financial QA benchmark +with **150 expert-annotated questions** across **40+ US public companies** (10-K/10-Q filings). + +### Evaluation Modes + +| Mode | Description | +|------|-------------| +| `singleDoc` | Each question searches only its target PDF (standard) | +| `sharedCorpus` | All questions search the full 41-PDF corpus | + +### Metrics + +- **3-Class Scoring**: Correct / Hallucination / Refusal (per FinanceBench paper) +- **EM / F1**: Exact Match and token-level F1 with financial value normalisation +- **Evidence Recall**: Retrieved pages vs gold evidence pages + +## Prerequisites + +### Step 1: Install Sirchmunk + +Install Sirchmunk from the repository root so that the `sirchmunk` CLI is available: + +```bash +# From repository root +pip install -e . +``` + +Verify the installation: + +```bash +sirchmunk --version +``` + +### Step 2: Prepare Dataset + +Download the [FinanceBench](https://huggingface.co/datasets/PatronusAI/financebench) +dataset and place the files under `benchmarks/financebench/data/`: + +``` +data/ +├── financebench_open_source.jsonl # 150 expert-annotated QA pairs +└── pdfs/ # 41 SEC-filing PDFs (10-K / 10-Q) + ├── 3M_2018_10K.pdf + ├── AMCOR_2023_10K.pdf + └── ... +``` + +Each PDF filename must match the `doc_name` field in the JSONL file. + +### Step 3: Initialize Experiment Workspace + +Initialize an isolated workspace for this experiment. This keeps the knowledge base +and cache separate from the default `~/.sirchmunk`: + +```bash +cd benchmarks/financebench +sirchmunk init --work-path .work +``` + +This creates a `.work/` directory containing a **platform .env** file (`.work/.env`). + +**Configure the platform .env** (`.work/.env`): + +This file controls the LLM provider used by Sirchmunk's search engine. +You **must** set valid LLM credentials here before proceeding. + +| Variable | Required | Description | Example | +|----------|----------|-------------|-----------------------------------------------------| +| `LLM_API_KEY` | **Yes** | API key for the LLM provider | `sk-xxx` | +| `LLM_BASE_URL` | **Yes** | LLM API endpoint | `https://dashscope.aliyuncs.com/compatible-mode/v1` | +| `LLM_MODEL_NAME` | **Yes** | Model name for search & QA | `qwen3.6-plus` | +| `LLM_TIMEOUT` | No | Request timeout in seconds | `120` | + +```bash +# Edit the platform .env +vi .work/.env +``` + +### Step 4: Knowledge Compiling + +Compile the PDF corpus into the experiment workspace so that Sirchmunk can search it: + +```bash +sirchmunk compile --work-path .work --paths data/pdfs +``` + +> **Note:** This step parses, chunks, and indexes all PDFs. +> For FinanceBench's all PDFs, expect hours of processing time, depending on your LLM speed and compute resources. + +#### Shallow Compile (Recommended for First Run) + +Use `--shallow` to skip tree indexing and only generate Summary + Topics. +This reduces LLM calls dramatically and achieves **5–9× speedup**: + +```bash +sirchmunk compile --work-path .work --paths data/pdfs --shallow +``` + +> **Tip:** `--shallow` is ideal for quickly compiling a large corpus on the first pass. +> You can run a normal (full) compile later to incrementally add tree indexes. + +### Step 5: Configure Experiment + +Create the **experiment .env** from the template: + +```bash +cp .env.example .env.financebench +``` + +**Configure the experiment .env** (`.env.financebench`): + +This file controls FinanceBench-specific evaluation parameters. + +#### Dataset Paths + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_WORK_PATH` | No | Isolated workspace path | `./.work` | +| `FB_DATA_DIR` | **Yes** | Directory containing `financebench_open_source.jsonl` | `./data` | +| `FB_PDF_DIR` | **Yes** | Directory containing the 41 PDF files | `./data/pdfs` | +| `FB_OUTPUT_DIR` | No | Results output directory | `./output` | + +#### Dataset Settings + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_LIMIT` | No | Number of questions to evaluate (`0` = all 150) | `0` | +| `FB_SEED` | No | Random seed for reproducibility | `42` | + +#### Search Settings + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_MODE` | No | Search mode: `FAST` or `DEEP` | `FAST` | +| `FB_TOP_K_FILES` | No | Max files returned per search | `5` | +| `FB_MAX_TOKEN_BUDGET` | No | Token budget for search context | `128000` | +| `FB_ENABLE_DIR_SCAN` | No | Enable directory-level scanning | `true` | + +#### Evaluation Settings + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_EVAL_MODE` | No | `singleDoc` (per-PDF) or `sharedCorpus` (all PDFs) | `singleDoc` | +| `FB_ENABLE_LLM_JUDGE` | No | Enable LLM Judge for semantic equivalence | `true` | +| `FB_EXTRACT_ANSWER` | No | Extract short answer from verbose response | `true` | + +#### Concurrency Settings + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_MAX_CONCURRENT` | No | Max concurrent evaluation requests | `3` | +| `FB_REQUEST_DELAY` | No | Delay between requests in seconds | `0.5` | + +**Optional LLM Override**: If you want this experiment to use a **different** LLM +than the platform config, uncomment the `LLM_*` lines in `.env.financebench`. +Otherwise, the experiment inherits LLM settings from `.work/.env`. + +```bash +# Edit the experiment .env +vi .env.financebench +``` + +## Configuration Architecture + +Configuration loads with layered inheritance (highest priority wins): + +``` +Priority (highest → lowest): +┌──────────────────────────────────┐ +│ Command-line args │ ← --limit N, --env +├──────────────────────────────────┤ +│ .env.financebench (experiment) │ ← FB_* params + optional LLM override +├──────────────────────────────────┤ +│ .work/.env (platform) │ ← LLM_API_KEY, LLM_MODEL_NAME, etc. +├──────────────────────────────────┤ +│ Environment variables │ ← os.environ fallback +├──────────────────────────────────┤ +│ Defaults │ ← Hard-coded in FinanceBenchConfig +└──────────────────────────────────┘ +``` + +### What Goes Where? + +| Setting | Platform `.work/.env` | Experiment `.env.financebench` | +|---------|:---------------------:|:------------------------------:| +| LLM API Key | ✅ (required) | Only if overriding | +| LLM Model | ✅ (required) | Only if overriding | +| LLM Base URL | ✅ (required) | Only if overriding | +| LLM Timeout | Optional | Only if overriding | +| PDF directory | — | ✅ (required) | +| Data directory | — | ✅ (required) | +| Output directory | — | Optional | +| Eval mode | — | Optional | +| Search mode | — | Optional | +| LLM Judge | — | Optional | +| Concurrency | — | Optional | + +### 1. Run + +```bash +# Run full benchmark (150 questions) +python run_benchmark.py + +# Run with custom config and question limit +python run_benchmark.py --env .env.custom --limit 20 +``` + +### 2. Analyze + +```bash +# Analyze a completed run +python analyze_results.py output/results_YYYYMMDD_HHMMSS.jsonl + +# Show more error cases +python analyze_results.py output/results_*.jsonl --max-errors 50 +``` + +## Data Format + +The dataset file `financebench_open_source.jsonl` contains one JSON object per line: + +```json +{ + "financebench_id": "financebench_id_00001", + "question": "What is the FY2018 capital expenditure amount for 3M?", + "answer": "$1,577.00", + "doc_name": "3M_2018_10K", + "company": "3M", + "question_type": "fact-based-w-numerical-answer", + "question_reasoning": "retrieve", + "evidence": [{"evidence_text": "...", "evidence_page_num": 42}] +} +``` + +## File Structure + +``` +benchmarks/financebench/ +├── .env.example # Config template (copy to .env.financebench) +├── config.py # FinanceBenchConfig dataclass +├── data_loader.py # Dataset + PDF corpus loader +├── evaluate.py # EM/F1/3-class scoring + aggregation +├── runner.py # Async batch runner (AgenticSearch) +├── run_benchmark.py # CLI entry point +├── analyze_results.py # Post-hoc analysis tool +├── data/ +│ ├── financebench_open_source.jsonl +│ └── pdfs/ # 41 SEC-filing PDFs +├── output/ # Results + metrics (auto-created) +└── logs/ # Run logs (auto-created) +``` + +## SOTA Reference + +| System | Accuracy | Coverage | +|--------|----------|----------| +| Mafin 2.5 (SOTA) | 98.7% | 100% | +| Fintool | 98.0% | 66.7% | +| Quantly | 94.0% | 100% | +| GPT-4 (zero-shot) | 29.3% | 100% | + +> Mafin 2.5 uses PageIndex + Agentic Vectorless RAG 3.0 architecture. diff --git a/benchmarks/financebench/analyze_results.py b/benchmarks/financebench/analyze_results.py new file mode 100644 index 0000000..a804284 --- /dev/null +++ b/benchmarks/financebench/analyze_results.py @@ -0,0 +1,316 @@ +"""Analyze FinanceBench benchmark results. + +Read a JSONL results file produced by ``run_benchmark.py`` and print a +comprehensive analysis report including per-type breakdowns, per-company +accuracy, error cases, and a SOTA comparison table. + +Usage: + python analyze_results.py output/results_YYYYMMDD_HHMMSS.jsonl + python analyze_results.py output/results_*.jsonl --max-errors 30 +""" +from __future__ import annotations + +import argparse +import json +import sys +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from evaluate import compute_metrics + + +# --------------------------------------------------------------------------- +# Data loading +# --------------------------------------------------------------------------- + + +def load_results(path: str) -> List[Dict[str, Any]]: + """Load a JSONL results file into a list of dicts. + + Args: + path: Path to a ``.jsonl`` file where each line is a JSON object. + + Returns: + List of result dicts. + + Raises: + FileNotFoundError: If *path* does not exist. + json.JSONDecodeError: If a line contains invalid JSON. + """ + p = Path(path) + if not p.exists(): + print(f"ERROR: file not found — {path}", file=sys.stderr) + sys.exit(1) + + results: list[dict] = [] + with open(p, encoding="utf-8") as f: + for lineno, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + results.append(json.loads(line)) + except json.JSONDecodeError as exc: + print(f"WARNING: skipping malformed line {lineno}: {exc}", file=sys.stderr) + return results + + +# --------------------------------------------------------------------------- +# Pretty-print helpers +# --------------------------------------------------------------------------- + + +def print_breakdown(title: str, breakdown: Dict[str, Dict[str, Any]]) -> None: + """Pretty-print a metrics breakdown table. + + Args: + title: Section header text. + breakdown: ``{group_name: {accuracy, hallucination_rate, ...}}``. + """ + print(f"\n=== Breakdown by {title} ===\n") + + # Determine if judge data is available + has_judge = any(m.get("llm_judge_accuracy") is not None for m in breakdown.values()) + + if has_judge: + header = f" {'Group':<30} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'Judge%':>7} {'N':>4}" + print(header) + print(" " + "-" * (len(header) - 2)) + + for group, m in sorted(breakdown.items(), key=lambda kv: -kv[1].get("accuracy", 0)): + acc = m.get("accuracy", 0) + hal = m.get("hallucination_rate", 0) + ref = m.get("refusal_rate", 0) + n = m.get("n", 0) + jdg = m.get("llm_judge_accuracy") + jdg_str = f"{jdg:>6.1f}" if jdg is not None else " N/A" + print(f" {group:<30} {acc:>5.1f} {hal:>7.1f} {ref:>7.1f} {jdg_str} {n:>4}") + else: + header = f" {'Group':<30} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}" + print(header) + print(" " + "-" * (len(header) - 2)) + + for group, m in sorted(breakdown.items(), key=lambda kv: -kv[1].get("accuracy", 0)): + acc = m.get("accuracy", 0) + hal = m.get("hallucination_rate", 0) + ref = m.get("refusal_rate", 0) + n = m.get("n", 0) + print(f" {group:<30} {acc:>5.1f} {hal:>7.1f} {ref:>7.1f} {n:>4}") + + +def _compute_company_breakdown( + results: List[Dict[str, Any]], +) -> List[Tuple[str, float, int, int, int]]: + """Group results by company and return sorted by accuracy ascending. + + Returns: + List of ``(company, accuracy, correct, total, halluc)`` tuples, + sorted by accuracy ascending (worst first). + """ + groups: dict[str, list[dict]] = defaultdict(list) + for r in results: + company = r.get("company", "unknown") or "unknown" + groups[company].append(r) + + rows: list[tuple[str, float, int, int, int]] = [] + for company, items in groups.items(): + n = len(items) + correct = sum(1 for r in items if r.get("classification") == "correct") + halluc = sum(1 for r in items if r.get("classification") == "hallucination") + acc = (correct / n * 100) if n else 0.0 + rows.append((company, acc, correct, n, halluc)) + + rows.sort(key=lambda x: x[1]) # worst first + return rows + + +def print_company_breakdown(results: List[Dict[str, Any]], top_n: int = 10) -> None: + """Print per-company accuracy table, showing worst *top_n* companies. + + Args: + results: List of per-question result dicts. + top_n: Number of worst-performing companies to display. + """ + rows = _compute_company_breakdown(results) + if not rows: + return + + print(f"\n=== Worst {top_n} Companies by Accuracy ===\n") + header = f" {'Company':<40} {'Acc%':>6} {'Correct':>8} {'Hallu':>6} {'N':>4}" + print(header) + print(" " + "-" * (len(header) - 2)) + + for company, acc, correct, n, halluc in rows[:top_n]: + print(f" {company:<40} {acc:>5.1f} {correct:>8} {halluc:>6} {n:>4}") + + +def print_error_cases(results: List[Dict[str, Any]], max_show: int = 20) -> None: + """Print detailed listing of error cases (hallucination + refusal). + + Args: + results: List of per-question result dicts. + max_show: Maximum number of error cases to display. + """ + errors = [r for r in results if r.get("classification") != "correct"] + if not errors: + print("\n=== Error Cases ===\n None — perfect score!") + return + + print(f"\n=== Error Cases ({len(errors)} total, showing up to {max_show}) ===\n") + + for i, r in enumerate(errors[:max_show], 1): + fb_id = r.get("financebench_id", "?") + cls = r.get("classification", "?") + question = r.get("question", "")[:100] + pred = r.get("prediction", "")[:80] + gold = r.get("gold_answer", "")[:80] + company = r.get("company", "") + em = r.get("em", False) + f1 = r.get("f1", 0.0) + + print(f" [{i:>2}] {fb_id} [{cls.upper()}]") + print(f" Company: {company}") + print(f" Question: {question}{'...' if len(r.get('question', '')) > 100 else ''}") + print(f" Predicted: {pred}{'...' if len(r.get('prediction', '')) > 80 else ''}") + print(f" Gold: {gold}{'...' if len(r.get('gold_answer', '')) > 80 else ''}") + print(f" EM={em} F1={f1:.3f}") + if r.get("error"): + print(f" Error: {r['error'][:120]}") + print() + + if len(errors) > max_show: + print(f" ... and {len(errors) - max_show} more error(s) not shown.\n") + + +def print_comparison_with_sota(metrics: Dict[str, Any]) -> None: + """Compare with published SOTA results on FinanceBench. + + Reference baselines from the FinanceBench leaderboard and recent papers. + """ + print("\n=== Comparison with SOTA ===\n") + header = f" {'System':<30} {'Accuracy':>10} {'Coverage':>10}" + print(header) + print(" " + "-" * (len(header) - 2)) + print(f" {'Mafin 2.5 (SOTA)':<30} {'98.7%':>10} {'100%':>10}") + print(f" {'Fintool':<30} {'98.0%':>10} {'66.7%':>10}") + print(f" {'Quantly':<30} {'94.0%':>10} {'100%':>10}") + print(f" {'GPT-4 (zero-shot)':<30} {'29.3%':>10} {'100%':>10}") + + acc = metrics.get("accuracy", 0) + n = metrics.get("n", 0) + coverage = min(100.0, n / 150.0 * 100) + print(f" {'Sirchmunk (This Run)':<30} {f'{acc:.1f}%':>10} {f'{coverage:.0f}%':>10}") + + # Show Judge Accuracy in SOTA table if available + judge_acc = metrics.get("llm_judge_accuracy") + if judge_acc is not None: + print(f" {'Sirchmunk (Judge Acc)':<30} {f'{judge_acc:.1f}%':>10} {f'{coverage:.0f}%':>10}") + + print(f"\n (This run evaluated {n} questions)") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Parse CLI arguments and generate a full analysis report.""" + parser = argparse.ArgumentParser( + description="Analyze FinanceBench benchmark results from a JSONL file", + ) + parser.add_argument( + "results_file", + help="Path to the results JSONL file produced by run_benchmark.py", + ) + parser.add_argument( + "--max-errors", + type=int, + default=20, + help="Maximum number of error cases to display (default: 20)", + ) + parser.add_argument( + "--top-companies", + type=int, + default=10, + help="Number of worst-performing companies to show (default: 10)", + ) + args = parser.parse_args() + + # Load + results = load_results(args.results_file) + if not results: + print("ERROR: no results loaded.", file=sys.stderr) + sys.exit(1) + + # Compute metrics + metrics = compute_metrics(results) + + # --- Overall summary --- + n = metrics.get("n", 0) + acc = metrics.get("accuracy", 0) + hallu = metrics.get("hallucination_rate", 0) + refuse = metrics.get("refusal_rate", 0) + avg_em = metrics.get("avg_em", 0) + avg_f1 = metrics.get("avg_f1", 0) + ev_recall = metrics.get("evidence_recall") + avg_latency = metrics.get("avg_latency", 0) + + print(f"\n{'=' * 60}") + print(f" FinanceBench Analysis ({n} questions)") + print(f"{'=' * 60}") + print(f" Accuracy: {acc:.1f}%") + print(f" Hallucination Rate: {hallu:.1f}%") + print(f" Refusal Rate: {refuse:.1f}%") + print(f" Avg EM: {avg_em:.3f}") + print(f" Avg F1: {avg_f1:.3f}") + if metrics.get("avg_evidence_recall") is not None: + print(f" Evidence Recall: {metrics['avg_evidence_recall']:.3f}") + else: + print(f" Evidence Recall: N/A (page-level telemetry unavailable)") + print(f" Avg Latency: {avg_latency:.1f}s") + + # LLM Judge independent metrics + if metrics.get("llm_judge_accuracy") is not None: + print(f"\n --- LLM Judge (Independent Evaluation) ---") + print(f" Judge Accuracy: {metrics['llm_judge_accuracy']:.1f}%") + print(f" Judge Correct: {metrics['llm_judge_correct']}/{metrics['llm_judge_count']}") + + # --- Breakdowns --- + if "by_question_type" in metrics: + print_breakdown("Question Type", metrics["by_question_type"]) + + if "by_question_reasoning" in metrics: + print_breakdown("Question Reasoning", metrics["by_question_reasoning"]) + + # --- Per-company breakdown (worst performers) --- + print_company_breakdown(results, top_n=args.top_companies) + + # --- Error cases --- + print_error_cases(results, max_show=args.max_errors) + + # --- Judge-Rule Discrepancies --- + discrepancies = [r for r in results + if r.get("llm_judge_correct") is not None + and r.get("classification") != "correct" + and r.get("llm_judge_correct") is True] + if discrepancies: + print(f"\n=== Judge-Rule Discrepancies ({len(discrepancies)} cases) ===") + print(" (Cases where LLM Judge says correct but EM/F1 says wrong)") + for r in discrepancies[:10]: + print(f" {r.get('financebench_id', 'N/A')}: pred='{r.get('prediction', '')[:50]}' gold='{r.get('gold_answer', '')[:50]}'") + print(f" classification={r.get('classification')}, judge_reasoning={r.get('llm_judge_reasoning', '')[:80]}") + if len(discrepancies) > 10: + print(f" ... and {len(discrepancies) - 10} more discrepancy(ies) not shown.") + + # --- SOTA comparison --- + print_comparison_with_sota(metrics) + + print(f"\n{'=' * 60}") + print(f" Source: {args.results_file}") + print(f"{'=' * 60}\n") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/financebench/config.py b/benchmarks/financebench/config.py new file mode 100644 index 0000000..68fe2a1 --- /dev/null +++ b/benchmarks/financebench/config.py @@ -0,0 +1,130 @@ +"""FinanceBench benchmark configuration.""" +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path + + +def _parse_env_file(path: str) -> dict[str, str]: + """Parse a .env file into a dict, handling comments, blank lines, and quotes.""" + result: dict[str, str] = {} + p = Path(path) + if not p.exists(): + return result + for line in p.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" not in line: + continue + k, v = line.split("=", 1) + v = v.strip() + # Strip surrounding quotes + if len(v) >= 2 and v[0] == v[-1] and v[0] in ('"', "'"): + v = v[1:-1] + result[k.strip()] = v + return result + + +@dataclass +class FinanceBenchConfig: + """All settings for a FinanceBench evaluation run.""" + + # LLM + llm_api_key: str = "" + llm_base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1" + llm_model: str = "qwen3.5-plus" + llm_timeout: int = 120 + + # Data paths + data_dir: str = "./data" + pdf_dir: str = "./data/pdfs" + output_dir: str = "./output" + + # Dataset + limit: int = 0 # 0 = all 150 + seed: int = 42 + + # Search + mode: str = "FAST" + top_k_files: int = 5 + max_token_budget: int = 128000 + enable_dir_scan: bool = True + + # Evaluation + eval_mode: str = "singleDoc" # singleDoc / sharedCorpus + enable_llm_judge: bool = True # LLM Judge drives Accuracy + Coverage evaluation + + # Concurrency + max_concurrent: int = 3 + request_delay: float = 0.5 + + # Experiment isolation + work_path: str = "./.work" # Isolated workspace for this experiment + + @classmethod + def from_env(cls, env_path: str = ".env.financebench") -> "FinanceBenchConfig": + """Load config with layer inheritance. + + Priority (highest to lowest): + 1. Experiment .env (.env.financebench) + 2. Platform .env (/.env, if exists) + 3. os.environ + 4. Dataclass defaults + """ + # Step 0: Pre-read experiment env to determine work_path + experiment_vars = _parse_env_file(env_path) + work_path = experiment_vars.get( + "FB_WORK_PATH", os.environ.get("FB_WORK_PATH", "./.work") + ) + + # Step 1: Load platform-level env (/.env) + platform_env_path = Path(work_path) / ".env" + platform_vars = _parse_env_file(str(platform_env_path)) + + # Step 2: Merge — experiment > platform > os.environ > defaults + merged = {**platform_vars, **experiment_vars} + + def _get(key: str, default: str = "") -> str: + return merged.get(key, os.environ.get(key, default)) + + def _bool(key: str, default: bool = False) -> bool: + v = _get(key, str(default)).lower() + return v in ("true", "1", "yes") + + def _int(key: str, default: int = 0) -> int: + try: + return int(_get(key, str(default))) + except (ValueError, TypeError): + return default + + def _float(key: str, default: float = 0.0) -> float: + try: + return float(_get(key, str(default))) + except (ValueError, TypeError): + return default + + return cls( + llm_api_key=_get("LLM_API_KEY"), + llm_base_url=_get( + "LLM_BASE_URL", + "https://dashscope.aliyuncs.com/compatible-mode/v1", + ), + llm_model=_get("LLM_MODEL_NAME", "qwen3.5-plus"), + llm_timeout=_int("LLM_TIMEOUT", 120), + data_dir=_get("FB_DATA_DIR", "./data"), + pdf_dir=_get("FB_PDF_DIR", "./data/pdfs"), + output_dir=_get("FB_OUTPUT_DIR", "./output"), + limit=_int("FB_LIMIT", 0), + seed=_int("FB_SEED", 42), + mode=_get("FB_MODE", "FAST"), + top_k_files=_int("FB_TOP_K_FILES", 5), + max_token_budget=_int("FB_MAX_TOKEN_BUDGET", 128000), + enable_dir_scan=_bool("FB_ENABLE_DIR_SCAN", True), + eval_mode=_get("FB_EVAL_MODE", "singleDoc"), + enable_llm_judge=_bool("FB_ENABLE_LLM_JUDGE", True), + max_concurrent=_int("FB_MAX_CONCURRENT", 3), + request_delay=_float("FB_REQUEST_DELAY", 0.5), + work_path=work_path, + ) diff --git a/benchmarks/financebench/data_loader.py b/benchmarks/financebench/data_loader.py new file mode 100644 index 0000000..7770865 --- /dev/null +++ b/benchmarks/financebench/data_loader.py @@ -0,0 +1,108 @@ +"""FinanceBench dataset loader.""" +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple + + +class FinanceBenchLoader: + """Load and validate FinanceBench JSONL data. + + Expects: + - ``data_dir/financebench_open_source.jsonl`` – 150 QA rows + - ``data_dir/financebench_document_information.jsonl`` – doc metadata (optional) + - ``pdf_dir/`` – corpus of 41 SEC-filing PDFs named by ``doc_name`` + """ + + _QUESTIONS_FILE = "financebench_open_source.jsonl" + _DOC_INFO_FILE = "financebench_document_information.jsonl" + + def __init__(self, data_dir: str, pdf_dir: str) -> None: + self._data_dir = Path(data_dir) + self._pdf_dir = Path(pdf_dir) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def load_questions(self) -> List[Dict[str, Any]]: + """Load the 150 open-source questions from JSONL. + + Raises: + FileNotFoundError: If the questions file is missing. + """ + path = self._data_dir / self._QUESTIONS_FILE + if not path.exists(): + raise FileNotFoundError(f"Questions file not found: {path}") + items: list[dict] = [] + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + items.append(json.loads(line)) + return items + + def load_doc_info(self) -> Dict[str, Dict[str, Any]]: + """Load document metadata, keyed by ``doc_name``. + + Returns an empty dict when the file is absent (it is optional). + """ + path = self._data_dir / self._DOC_INFO_FILE + if not path.exists(): + return {} + result: dict[str, dict] = {} + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + obj = json.loads(line) + doc_name = obj.get("doc_name", "") + if doc_name: + result[doc_name] = obj + return result + + def get_pdf_path(self, doc_name: str) -> Optional[str]: + """Resolve *doc_name* to a PDF file path. + + Resolution order: + 1. ``/.pdf`` + 2. ``/`` (file with no extension) + 3. Case-insensitive stem match across ``pdf_dir`` + """ + candidates = [ + self._pdf_dir / f"{doc_name}.pdf", + self._pdf_dir / doc_name, + ] + for c in candidates: + if c.exists(): + return str(c) + # Case-insensitive fallback + if self._pdf_dir.exists(): + lower = doc_name.lower() + for f in self._pdf_dir.iterdir(): + if f.stem.lower() == lower: + return str(f) + return None + + def get_unique_docs(self, questions: List[Dict[str, Any]]) -> Set[str]: + """Extract the unique set of ``doc_name`` values from *questions*.""" + return {q["doc_name"] for q in questions if "doc_name" in q} + + def validate_corpus( + self, questions: List[Dict[str, Any]] + ) -> Tuple[int, List[str]]: + """Check PDF availability for all referenced documents. + + Returns: + A tuple of ``(found_count, missing_doc_names)``. + """ + docs = self.get_unique_docs(questions) + missing: list[str] = [] + found = 0 + for doc in sorted(docs): + if self.get_pdf_path(doc): + found += 1 + else: + missing.append(doc) + return found, missing diff --git a/benchmarks/financebench/evaluate.py b/benchmarks/financebench/evaluate.py new file mode 100644 index 0000000..d9614f3 --- /dev/null +++ b/benchmarks/financebench/evaluate.py @@ -0,0 +1,209 @@ +"""FinanceBench evaluation metrics — LLM Judge driven. + +All correctness evaluation (Accuracy, Coverage) is driven by the LLM Judge. +This module aggregates per-question judge results into benchmark-level metrics. + +The ``normalize_answer`` helper is retained for quick short-circuit checks +inside the judge (exact-match bypass before calling the LLM). +""" +from __future__ import annotations + +import re +from collections import defaultdict +from typing import Any, Dict, List + +# ------------------------------------------------------------------ +# Constants +# ------------------------------------------------------------------ + +# Markdown / wrapper patterns compiled once +_RE_BOLD = re.compile(r"\*\*(.+?)\*\*") +_RE_ITALIC = re.compile(r"\*(.+?)\*") +_RE_QUOTES = re.compile(r'^["\u201c\u201d\']+|["\u201c\u201d\']+$') +_RE_ANSWER_PREFIX = re.compile( + r"^(the\s+(short\s+)?answer\s+is\s*:?\s*|answer\s*:\s*|short\s+answer\s*:\s*)", + re.IGNORECASE, +) +# Financial value helpers +_RE_DOLLAR = re.compile(r"^\$\s*") +_RE_THOUSAND_SEP = re.compile(r",(\d{3})") +_RE_TRAILING_ZEROS = re.compile(r"\.0+$") + + +# ------------------------------------------------------------------ +# Normalisation +# ------------------------------------------------------------------ + + +def normalize_answer(answer: str) -> str: + """Normalise an answer string for comparison. + + Steps: + 1. Strip Markdown bold / italic. + 2. Strip surrounding quotes. + 3. Strip trailing punctuation (``.``, ``:``). + 4. Remove common LLM wrapper phrases. + 5. Financial value normalisation (currency, commas, trailing zeros). + 6. Lowercase. + """ + s = answer.strip() + if not s: + return "" + + # 1. Markdown + s = _RE_BOLD.sub(r"\1", s) + s = _RE_ITALIC.sub(r"\1", s) + + # 2. Quotes + s = _RE_QUOTES.sub("", s).strip() + + # 3. Trailing punctuation + s = s.rstrip(".:") + + # 4. Wrapper phrases + s = _RE_ANSWER_PREFIX.sub("", s).strip() + + # 5. Financial normalisation + s = _normalize_financial_value(s) + + # 6. Lowercase + return s.lower().strip() + + +def _normalize_financial_value(text: str) -> str: + """Normalise financial figures for robust comparison. + + - ``$1,577.00`` → ``1577`` + - ``15.3%`` → ``15.3%`` + - ``$1577`` → ``1577`` + - ``1,577`` → ``1577`` + - ``($500)`` → ``-500`` + - ``-$500`` → ``-500`` + """ + s = text.strip() + + # Handle accounting bracket notation for negatives: ($500) → -$500 + if s.startswith("(") and s.endswith(")"): + s = "-" + s[1:-1] + + # Handle negative sign: remember it, strip it for processing + negative = False + if s.startswith("-"): + negative = True + s = s[1:] + + # Detect if value looks numeric (possibly with $ / % / commas) + stripped_for_check = _RE_DOLLAR.sub("", s) + stripped_for_check = stripped_for_check.replace(",", "").rstrip("%").strip() + try: + float(stripped_for_check) + except ValueError: + # Not a numeric value – restore negative sign and return as-is + return ("-" + s) if negative else s + + # Remove dollar sign + s = _RE_DOLLAR.sub("", s) + + # Remember and temporarily strip percentage + has_pct = s.endswith("%") + if has_pct: + s = s[:-1].strip() + + # Remove thousand-separator commas + s = s.replace(",", "") + + # Remove trailing decimal zeros: 1577.00 → 1577, 15.30 → 15.3 + if "." in s: + s = s.rstrip("0").rstrip(".") + + # Re-attach percentage + if has_pct: + s = s + "%" + + # Re-attach negative sign + if negative and not s.startswith("-"): + s = "-" + s + + return s + + +# ------------------------------------------------------------------ +# Aggregate metrics +# ------------------------------------------------------------------ + + +def compute_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate per-question results into benchmark-level metrics. + + All correctness evaluation is driven by LLM Judge results stored in + each result dict (``judge_correct``, ``coverage``). + + Returns a dict with overall stats plus breakdown by *question_type*. + """ + n = len(results) + if n == 0: + return {"n": 0} + + # --- Accuracy (Judge) --- + judge_correct = sum(1 for r in results if r.get("judge_correct")) + + # --- Coverage (Judge) --- + coverage_true = sum(1 for r in results if r.get("coverage")) + + # --- Latency --- + latencies = [r["elapsed"] for r in results if "elapsed" in r] + avg_latency = sum(latencies) / len(latencies) if latencies else 0.0 + total_time = sum(latencies) + + # --- Token usage --- + search_tokens = sum( + r.get("telemetry", {}).get("total_tokens", 0) for r in results + ) + judge_tokens = sum(r.get("judge_tokens", 0) for r in results) + total_tokens = search_tokens + judge_tokens + avg_tokens_per_question = total_tokens / n if n else 0 + + overall: Dict[str, Any] = { + "n": n, + "accuracy": round(judge_correct / n * 100, 2), + "coverage": round(coverage_true / n * 100, 2), + "avg_latency": round(avg_latency, 2), + "total_time_seconds": round(total_time, 2), + "token_usage": { + "total_tokens": total_tokens, + "search_tokens": search_tokens, + "judge_tokens": judge_tokens, + "avg_tokens_per_question": round(avg_tokens_per_question, 1), + }, + "judge_correct": judge_correct, + "coverage_true": coverage_true, + "by_question_type": _breakdown(results, "question_type"), + } + + return overall + + +def _breakdown( + results: List[Dict[str, Any]], key: str +) -> Dict[str, Dict[str, Any]]: + """Compute per-group accuracy / coverage breakdown.""" + groups: dict[str, list[dict]] = defaultdict(list) + for r in results: + group = r.get(key) or "unknown" + groups[group].append(r) + + out: dict[str, dict] = {} + for group, items in sorted( + groups.items(), key=lambda x: (x[0] is None, x[0] or "") + ): + g_n = len(items) + g_correct = sum(1 for r in items if r.get("judge_correct")) + g_coverage = sum(1 for r in items if r.get("coverage")) + out[group] = { + "n": g_n, + "accuracy": round(g_correct / g_n * 100, 2) if g_n else 0.0, + "coverage": round(g_coverage / g_n * 100, 2) if g_n else 0.0, + "judge_count": g_n, + "judge_correct": g_correct, + } + return out diff --git a/benchmarks/financebench/judge.py b/benchmarks/financebench/judge.py new file mode 100644 index 0000000..1e5e1ca --- /dev/null +++ b/benchmarks/financebench/judge.py @@ -0,0 +1,609 @@ +"""LLM-based judge for FinanceBench evaluation. + +The judge drives **all** evaluation decisions: +- **Accuracy**: whether the prediction is semantically equivalent to the gold answer. +- **Coverage**: whether the prediction contains any information relevant to the question. + +This replaces the previous EM/F1 rule-driven pipeline with a single LLM-based +evaluation authority, providing more nuanced correctness signals for financial QA. +""" + +from __future__ import annotations + +import json +import logging +import re +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +_JUDGE_PROMPT = """\ +You are an expert financial analyst and auditor evaluating answer correctness \ +with **zero tolerance for numerical or factual errors**. + +Question: {question} +Gold Answer: {gold} +Model Prediction: {prediction} + +Task: Determine if the model's prediction is **semantically equivalent** \ +to the gold answer in the context of this financial question. + +═══════════════════════════════════════════════ +EQUIVALENT — only when ALL of the following hold: +═══════════════════════════════════════════════ + +1. **Numerical precision (ZERO TOLERANCE)**: + - Values must be mathematically identical after unit conversion. + - $1.5B = $1,500M = $1,500,000K = $1,500,000,000 ✓ + - $1,577 ≠ $1,580 ✗ (rounding is NOT acceptable) + - 15.3% = 15.30% = 0.153 ✓ but 15.3% ≠ 15% ✗ + - $1.5M ≠ $1.5B ✗ (unit mismatch is a critical error) + +2. **Negative / bracket notation**: + - ($500) = -$500 = -500 ✓ + - ($500) ≠ $500 ✗ (sign matters) + +3. **Time period / fiscal year**: + - FY2018 = fiscal year 2018 = 2018 ✓ + - FY2018 ≠ FY2019 ✗ (different fiscal year — NEVER equivalent) + - Q3 2019 ≠ Q4 2019 ✗ (different quarter) + - "year ended December 2018" = FY2018 ✓ + +4. **Currency formatting**: + - $1,577.00 = $1577 = 1577 ✓ (same value, format differs) + +5. **Financial term equivalences (accepted)**: + - net income = net profit ✓ + - CAPEX = capital expenditure ✓ + - EPS = earnings per share ✓ + - EBITDA = earnings before interest, taxes, depreciation and amortization ✓ + - YoY = year-over-year ✓ + - COGS = cost of goods sold ✓ + - D&A = depreciation and amortization ✓ + +6. **Financial term distinctions (NOT interchangeable)**: + - revenue ≠ net revenue ≠ gross revenue (unless context is clear) + - operating income ≠ net income + - gross profit ≠ net profit + - total assets ≠ net assets + +7. **Prediction with extra context**: + - If prediction contains the correct answer with additional supporting \ + detail, treat as equivalent (e.g., "Revenue was $1,577M in FY2018" \ + vs "$1,577M" — equivalent, provided the value is correct). + +═══════════════════════════════════════════════ +NOT EQUIVALENT — if ANY of the following hold: +═══════════════════════════════════════════════ + +1. Different numerical values (even slightly: $1,577 ≠ $1,580) +2. Different time periods or fiscal years +3. Different companies or entities +4. Opposite trend direction (increased ≠ decreased, growth ≠ decline) +5. Unit mismatch ($1.5M ≠ $1.5B) +6. Missing or wrong sign (positive ≠ negative) +7. Prediction is vague or hedging where gold is precise +8. Prediction is a refusal or states it cannot find the answer +9. Near-approximate values that are not mathematically equal after unit conversion + +═══════════════════════════════════════════════ +CONSERVATIVE JUDGMENT POLICY +═══════════════════════════════════════════════ + +- **When in doubt, judge as NOT equivalent.** Financial accuracy demands \ + precision; a false positive (incorrectly marking wrong answer as correct) \ + is far worse than a false negative. +- If you are less than 80% confident the answers are equivalent, \ + judge as NOT equivalent. +- Set confidence to reflect your actual certainty (0.0 = no idea, \ + 1.0 = absolutely certain). + +═══════════════════════════════════════════════ +FEW-SHOT EXAMPLES +═══════════════════════════════════════════════ + +Example 1 — EQUIVALENT (format difference): + Gold: "$1,577" | Prediction: "$1,577.00 million" + → {{"equivalent": true, "confidence": 0.95, "reasoning": "Same value $1,577M, trailing zeros are formatting."}} + +Example 2 — EQUIVALENT (abbreviation): + Gold: "$1.5 billion" | Prediction: "$1,500M" + → {{"equivalent": true, "confidence": 0.97, "reasoning": "$1.5B = $1,500M, correct unit conversion."}} + +Example 3 — NOT EQUIVALENT (different value): + Gold: "$1,577" | Prediction: "$1,580" + → {{"equivalent": false, "confidence": 0.99, "reasoning": "Values differ: 1577 ≠ 1580. No rounding tolerance."}} + +Example 4 — NOT EQUIVALENT (different fiscal year): + Gold: "FY2018" | Prediction: "FY2019" + → {{"equivalent": false, "confidence": 1.0, "reasoning": "Different fiscal years."}} + +Example 5 — NOT EQUIVALENT (unit mismatch): + Gold: "$1.5 million" | Prediction: "$1.5 billion" + → {{"equivalent": false, "confidence": 1.0, "reasoning": "Unit mismatch: million ≠ billion."}} + +Example 6 — EQUIVALENT (negative notation): + Gold: "-$500" | Prediction: "($500)" + → {{"equivalent": true, "confidence": 0.98, "reasoning": "Same negative value, bracket = negative."}} + +Respond ONLY with a JSON object (no markdown, no extra text): +{{"equivalent": true or false, "confidence": 0.0 to 1.0, "reasoning": "brief explanation"}}""" + + +# Refusal detection phrases (subset for quick judge-side check) +_REFUSAL_INDICATORS: frozenset[str] = frozenset( + { + "i cannot", + "i can't", + "unable to", + "not able to", + "i don't know", + "i do not know", + "unknown", + "no results found", + "cannot determine", + "insufficient data", + "data not found", + "could not find", + "couldn't find", + "unable to determine", + "unable to find", + } +) + + +class FinanceBenchLLMJudge: + """LLM-based judge driving all FinanceBench evaluation. + + Provides two evaluation axes: + - ``judge()``: semantic equivalence (Accuracy). + - ``judge_coverage()``: information relevance (Coverage). + + Token usage from every LLM call is tracked and returned. + """ + + _CONFIDENCE_THRESHOLD: float = 0.7 + _MAX_RETRIES: int = 2 + + # Coverage evaluation prompt + _COVERAGE_PROMPT: str = """\ +You are evaluating whether a system's response contains ANY useful information \ +relevant to the given financial question. + +Question: {question} +System Response: {prediction} + +Task: Determine if the response contains relevant, useful information. + +═══════════════════════════════════════════════ +HAS COVERAGE (has_coverage = true) — when ANY of: +═══════════════════════════════════════════════ +1. Contains specific financial data (dollar amounts, percentages, ratios) +2. Contains relevant factual statements about the company or topic +3. Contains partial but concrete information related to the question +4. Provides a direct answer (even if potentially incorrect) + +═══════════════════════════════════════════════ +NO COVERAGE (has_coverage = false) — when ALL of: +═══════════════════════════════════════════════ +1. Response is a refusal ("I cannot", "No results found", etc.) +2. Response contains no concrete data related to the question +3. Response is empty, purely apologetic, or only contains generic filler + +Respond ONLY with a JSON object (no markdown, no extra text): +{{"has_coverage": true or false, "confidence": 0.0 to 1.0, "reasoning": "brief explanation"}}""" + + def __init__(self, llm: Any) -> None: + self._llm = llm + self._cache: Dict[tuple, Dict[str, Any]] = {} + self._total_tokens_used: int = 0 + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def judge( + self, + prediction: str, + gold_answer: str, + question: str = "", + ) -> Dict[str, Any]: + """Judge whether prediction is semantically equivalent to gold. + + Args: + prediction: Model's answer text. + gold_answer: Ground-truth answer text. + question: The original question (for context). + + Returns: + { + "equivalent": bool, + "confidence": float (0-1), + "reasoning": str, + "cached": bool, + "error": Optional[str], + "tokens_used": int, + } + """ + # --- Refusal short-circuit (saves LLM call) --- + if self._is_refusal(prediction): + return { + "equivalent": False, + "confidence": 1.0, + "reasoning": "Prediction is a refusal — skipped LLM judge.", + "cached": False, + "error": None, + "tokens_used": 0, + } + + # --- Quick exact-match shortcut --- + from evaluate import normalize_answer + + if normalize_answer(prediction) == normalize_answer(gold_answer): + return { + "equivalent": True, + "confidence": 1.0, + "reasoning": "Normalized exact match", + "cached": False, + "error": None, + "tokens_used": 0, + } + + # --- Check cache (key includes question for context-sensitivity) --- + cache_key = ( + question.strip().lower(), + prediction.strip().lower(), + gold_answer.strip().lower(), + ) + if cache_key in self._cache: + result = dict(self._cache[cache_key]) + result["cached"] = True + return result + + # --- Call LLM with retry --- + prompt = _JUDGE_PROMPT.format( + question=question or "N/A", + gold=gold_answer, + prediction=prediction, + ) + + result: Dict[str, Any] | None = None + last_error: str | None = None + tokens_used: int = 0 + + for attempt in range(1, self._MAX_RETRIES + 1): + try: + resp = await self._llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + tokens_used = self._extract_tokens(resp) + raw = resp.content.strip() + result = self._parse_response(raw) + if result.get("error") is None: + break # success + last_error = result.get("error") + except Exception as e: + last_error = str(e) + logger.warning( + "LLM Judge call failed (attempt %d/%d): %s", + attempt, + self._MAX_RETRIES, + e, + ) + result = None + + if result is None or result.get("error") is not None: + result = { + "equivalent": False, + "confidence": 0.0, + "reasoning": f"Judge error after {self._MAX_RETRIES} attempts: {last_error}", + "error": last_error, + } + + # --- Apply confidence threshold (conservative) --- + if ( + result.get("error") is None + and result["equivalent"] + and result["confidence"] < self._CONFIDENCE_THRESHOLD + ): + result["equivalent"] = False + result["reasoning"] = ( + f"Overridden to NOT equivalent: confidence " + f"{result['confidence']:.2f} < threshold " + f"{self._CONFIDENCE_THRESHOLD} — conservative policy. " + f"Original reasoning: {result['reasoning']}" + ) + + result.setdefault("cached", False) + result.setdefault("error", None) + result["tokens_used"] = tokens_used + self._total_tokens_used += tokens_used + + # Cache successful results only + if result["error"] is None: + self._cache[cache_key] = { + k: v for k, v in result.items() if k != "cached" + } + + return result + + # ------------------------------------------------------------------ + # Parsing + # ------------------------------------------------------------------ + + def _parse_response(self, raw: str) -> Dict[str, Any]: + """Parse LLM JSON response with robust fallback heuristics.""" + # --- Try direct JSON parse --- + parsed = self._try_parse_json(raw) + if parsed is not None: + return self._validated_result(parsed, raw) + + # --- Fallback: keyword detection (conservative) --- + lower = raw.lower() + + # Look for explicit true/false patterns with word boundaries + true_match = re.search( + r'"equivalent"\s*:\s*true\b', lower + ) + false_match = re.search( + r'"equivalent"\s*:\s*false\b', lower + ) + + if false_match and not true_match: + return { + "equivalent": False, + "confidence": 0.5, + "reasoning": f"Keyword fallback (NOT equivalent): {raw[:200]}", + } + elif true_match and not false_match: + # Conservative: lower confidence for keyword-only parse + return { + "equivalent": True, + "confidence": 0.5, + "reasoning": f"Keyword fallback (equivalent): {raw[:200]}", + } + + # --- Cannot parse → conservative default --- + logger.warning("Cannot parse judge response: %s", raw[:200]) + return { + "equivalent": False, + "confidence": 0.0, + "reasoning": f"Unparseable response: {raw[:200]}", + "error": "parse_error", + } + + def _try_parse_json(self, raw: str) -> Optional[Dict[str, Any]]: + """Attempt multiple JSON extraction strategies.""" + strategies = [ + raw.strip(), + # Strip markdown code fences + re.sub(r"```(?:json)?\s*\n?", "", raw).strip().rstrip("`").strip(), + # Extract first {...} block + self._extract_json_block(raw), + ] + + for text in strategies: + if not text: + continue + # Fix common LLM JSON quirks + text = self._fix_json_quirks(text) + try: + return json.loads(text) + except (json.JSONDecodeError, ValueError): + continue + return None + + @staticmethod + def _extract_json_block(raw: str) -> Optional[str]: + """Extract the first {...} JSON object from raw text.""" + match = re.search(r"\{[^{}]*\}", raw, re.DOTALL) + return match.group(0) if match else None + + @staticmethod + def _fix_json_quirks(text: str) -> str: + """Fix common non-standard JSON from LLMs.""" + # Replace single quotes with double quotes (basic heuristic) + # Only if the text doesn't already have double quotes for keys + if "'" in text and '"' not in text: + text = text.replace("'", '"') + # Remove trailing commas before closing braces + text = re.sub(r",\s*}", "}", text) + text = re.sub(r",\s*]", "]", text) + return text + + def _validated_result( + self, obj: Dict[str, Any], raw: str + ) -> Dict[str, Any]: + """Build a validated result dict from parsed JSON, clamping values.""" + equivalent = bool(obj.get("equivalent", False)) + + # Clamp confidence to [0.0, 1.0] + try: + confidence = float(obj.get("confidence", 0.0)) + except (ValueError, TypeError): + confidence = 0.0 + confidence = max(0.0, min(1.0, confidence)) + + reasoning = str(obj.get("reasoning", "")) + + return { + "equivalent": equivalent, + "confidence": confidence, + "reasoning": reasoning, + } + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _is_refusal(text: str) -> bool: + """Quick check whether *text* looks like a refusal / non-answer. + + When the text contains an explicit ``**Answer: xxx**`` marker, + only the answer value is checked for refusal phrases so that + reasoning text containing phrases like "insufficient data" (as + analytical context) does not trigger a false positive. + """ + if not text or not text.strip(): + return True + lower = text.strip().lower() + if lower in ("unknown", "n/a", "none", ""): + return True + + # If there is an explicit **Answer: xxx** marker, only check that value + answer_match = re.search(r'\*\*answer:\s*(.+?)\*\*', lower) + if answer_match: + answer_val = answer_match.group(1).strip() + for phrase in _REFUSAL_INDICATORS: + if phrase in answer_val: + return True + return False + + # No structured answer marker — check the leading portion only + check_region = lower[:300] + for phrase in _REFUSAL_INDICATORS: + if phrase in check_region: + return True + return False + + async def judge_coverage( + self, + prediction: str, + question: str, + ) -> Dict[str, Any]: + """Evaluate whether *prediction* contains relevant information for *question*. + + Returns: + { + "has_coverage": bool, + "confidence": float (0-1), + "reasoning": str, + "tokens_used": int, + "error": Optional[str], + } + """ + # --- Refusal short-circuit --- + if self._is_refusal(prediction): + return { + "has_coverage": False, + "confidence": 1.0, + "reasoning": "Explicit refusal detected.", + "tokens_used": 0, + "error": None, + } + + prompt = self._COVERAGE_PROMPT.format( + question=question or "N/A", + prediction=prediction[:4000], + ) + + result: Dict[str, Any] | None = None + last_error: str | None = None + tokens_used: int = 0 + + for attempt in range(1, self._MAX_RETRIES + 1): + try: + resp = await self._llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + tokens_used = self._extract_tokens(resp) + raw = resp.content.strip() + result = self._parse_coverage_response(raw) + if result.get("error") is None: + break + last_error = result.get("error") + except Exception as e: + last_error = str(e) + logger.warning( + "LLM Coverage judge failed (attempt %d/%d): %s", + attempt, + self._MAX_RETRIES, + e, + ) + result = None + + if result is None or result.get("error") is not None: + result = { + "has_coverage": False, + "confidence": 0.0, + "reasoning": f"Coverage judge error after {self._MAX_RETRIES} attempts: {last_error}", + "error": last_error, + } + + result.setdefault("error", None) + result["tokens_used"] = tokens_used + self._total_tokens_used += tokens_used + return result + + # ------------------------------------------------------------------ + # Coverage response parsing + # ------------------------------------------------------------------ + + def _parse_coverage_response(self, raw: str) -> Dict[str, Any]: + """Parse LLM JSON response for coverage evaluation.""" + parsed = self._try_parse_json(raw) + if parsed is not None: + has_coverage = bool(parsed.get("has_coverage", False)) + try: + confidence = float(parsed.get("confidence", 0.0)) + except (ValueError, TypeError): + confidence = 0.0 + confidence = max(0.0, min(1.0, confidence)) + reasoning = str(parsed.get("reasoning", "")) + return { + "has_coverage": has_coverage, + "confidence": confidence, + "reasoning": reasoning, + } + + # Fallback: keyword detection + lower = raw.lower() + true_match = re.search(r'"has_coverage"\s*:\s*true\b', lower) + false_match = re.search(r'"has_coverage"\s*:\s*false\b', lower) + + if false_match and not true_match: + return { + "has_coverage": False, + "confidence": 0.5, + "reasoning": f"Keyword fallback (no coverage): {raw[:200]}", + } + elif true_match and not false_match: + return { + "has_coverage": True, + "confidence": 0.5, + "reasoning": f"Keyword fallback (has coverage): {raw[:200]}", + } + + logger.warning("Cannot parse coverage response: %s", raw[:200]) + return { + "has_coverage": False, + "confidence": 0.0, + "reasoning": f"Unparseable response: {raw[:200]}", + "error": "parse_error", + } + + # ------------------------------------------------------------------ + # Token tracking + # ------------------------------------------------------------------ + + @staticmethod + def _extract_tokens(resp: Any) -> int: + """Extract total token count from an LLM response.""" + usage = getattr(resp, "usage", None) + if isinstance(usage, dict): + return int(usage.get("total_tokens", 0)) + return 0 + + @property + def total_tokens_used(self) -> int: + """Cumulative tokens consumed by all judge calls.""" + return self._total_tokens_used + + @property + def cache_size(self) -> int: + """Return the number of cached judge results.""" + return len(self._cache) diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py new file mode 100644 index 0000000..183a6d3 --- /dev/null +++ b/benchmarks/financebench/run_benchmark.py @@ -0,0 +1,319 @@ +"""FinanceBench benchmark entry point. + +Usage: + cd benchmarks/financebench + python run_benchmark.py [--env .env.financebench] [--limit N] + +Examples: + # Run all 150 questions with default config + python run_benchmark.py + + # Run a quick sanity check with 10 questions + python run_benchmark.py --limit 10 + + # Use a custom .env file + python run_benchmark.py --env .env.custom --limit 20 +""" +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import random +import sys +import time +from datetime import datetime +from pathlib import Path +from typing import List + +from dotenv import load_dotenv + +from config import FinanceBenchConfig +from data_loader import FinanceBenchLoader +from evaluate import compute_metrics +from runner import run_batch + +# --------------------------------------------------------------------------- +# Tee stdout to log file +# --------------------------------------------------------------------------- + + +class _TeeWriter: + """Duplicate stdout to both terminal and a log file.""" + + def __init__(self, log_path: str) -> None: + self._terminal = sys.stdout + self._log = open(log_path, "w", encoding="utf-8") # noqa: SIM115 + + def write(self, msg: str) -> int: + self._terminal.write(msg) + self._log.write(msg) + return len(msg) + + def flush(self) -> None: + self._terminal.flush() + self._log.flush() + + def close(self) -> None: + self._log.close() + + # Let logging / other code check the stream capabilities + @property + def encoding(self) -> str: + return getattr(self._terminal, "encoding", "utf-8") + + def isatty(self) -> bool: + return False + + def fileno(self) -> int: + return self._terminal.fileno() + + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- + + +def setup_logging(output_dir: str, ts: str | None = None) -> tuple[str, str]: + """Configure logging to file + console. + + Creates a timestamped log file under ``logs/`` (relative to *output_dir*'s + parent, i.e. the benchmark root directory). + + Returns: + Tuple of (absolute path to the log file, timestamp string). + """ + log_dir = Path("logs") + log_dir.mkdir(parents=True, exist_ok=True) + + if ts is None: + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + log_path = log_dir / f"benchmark_{ts}.log" + + root_logger = logging.getLogger("financebench") + root_logger.setLevel(logging.DEBUG) + + # File handler – DEBUG level, full detail + fh = logging.FileHandler(str(log_path), encoding="utf-8") + fh.setLevel(logging.DEBUG) + fh.setFormatter( + logging.Formatter( + "%(asctime)s %(name)-28s %(levelname)-7s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + + # Console handler – INFO level, concise + ch = logging.StreamHandler(sys.stdout) + ch.setLevel(logging.INFO) + ch.setFormatter( + logging.Formatter("%(asctime)s %(levelname)-7s %(message)s", datefmt="%H:%M:%S") + ) + + root_logger.addHandler(fh) + root_logger.addHandler(ch) + + return str(log_path.resolve()), ts + + +# --------------------------------------------------------------------------- +# Summary printing +# --------------------------------------------------------------------------- + + +def _print_summary( + results: List[dict], + metrics: dict, + total_time: float, + results_path: Path, + metrics_path: Path, + log_path: str, +) -> None: + """Print a human-readable run summary to stdout.""" + n = len(results) + acc = metrics.get("accuracy", 0) + cov = metrics.get("coverage", 0) + avg_latency = metrics.get("avg_latency", 0) + + token_usage = metrics.get("token_usage", {}) + total_tokens = token_usage.get("total_tokens", 0) + search_tokens = token_usage.get("search_tokens", 0) + judge_tokens = token_usage.get("judge_tokens", 0) + avg_tokens_q = token_usage.get("avg_tokens_per_question", 0) + + print("\n" + "=" * 60) + print(f"FinanceBench Results ({n} questions)") + print("=" * 60) + print(f" Accuracy (Judge): {acc:.1f}%") + print(f" Coverage (Judge): {cov:.1f}%") + print(f" Avg Latency: {avg_latency:.1f}s") + print(f" Total Time: {total_time:.1f}s") + + print(f"\n --- Token Usage ---") + print(f" Total Tokens: {total_tokens:>,}") + print(f" Search Tokens: {search_tokens:>,}") + print(f" Judge Tokens: {judge_tokens:>,}") + print(f" Avg per Question: {avg_tokens_q:>,.0f}") + + print(f"\n Results: {results_path}") + print(f" Metrics: {metrics_path}") + print(f" Log: {log_path}") + + # Breakdown by question_type + by_qt = metrics.get("by_question_type") + if by_qt: + print(f"\n {'Question Type':<28} {'Acc%':>6} {'Cover%':>7} {'N':>5}") + print(" " + "-" * 48) + for qt, m in sorted(by_qt.items()): + qt_acc = m.get("accuracy", 0) + qt_cov = m.get("coverage", 0) + qt_n = m.get("n", 0) + print(f" {qt:<28} {qt_acc:>5.1f} {qt_cov:>7.1f} {qt_n:>5}") + + print("=" * 60) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Parse CLI arguments, run the benchmark, and save results.""" + parser = argparse.ArgumentParser( + description="Run FinanceBench benchmark against Sirchmunk AgenticSearch", + ) + parser.add_argument( + "--env", + default=".env.financebench", + help="Path to .env config file (default: .env.financebench)", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Override FB_LIMIT — number of questions to evaluate", + ) + args = parser.parse_args() + + # Load .env into os.environ so SIRCHMUNK_* variables are visible globally + load_dotenv(args.env, override=True) + + # 1. Load config + cfg = FinanceBenchConfig.from_env(args.env) + if args.limit is not None: + cfg.limit = args.limit + + # 2. Setup logging + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + log_path, ts = setup_logging(cfg.output_dir, ts=ts) + logger = logging.getLogger("financebench") + + # 2b. Tee stdout → debug log so SEARCH_WIKI_DEBUG prints are captured + log_dir = Path("logs") + log_dir.mkdir(parents=True, exist_ok=True) + debug_log_path = log_dir / f"benchmark_{ts}_debug.log" + tee = _TeeWriter(str(debug_log_path)) + sys.stdout = tee + + # Print config source info + work_env = Path(cfg.work_path) / ".env" + logger.info("=" * 50) + logger.info("FinanceBench Configuration") + logger.info("=" * 50) + logger.info(" Experiment env : %s", args.env) + logger.info(" Platform env : %s (%s)", work_env, "found" if work_env.exists() else "not found") + logger.info(" Work path : %s", Path(cfg.work_path).resolve()) + logger.info(" LLM : %s @ %s", cfg.llm_model, cfg.llm_base_url) + logger.info(" Eval mode : %s", cfg.eval_mode) + logger.info(" Search mode : %s, Top-K: %d", cfg.mode, cfg.top_k_files) + logger.info(" LLM Judge : %s", "enabled" if cfg.enable_llm_judge else "disabled") + logger.info("=" * 50) + + # 3. Load data + loader = FinanceBenchLoader(cfg.data_dir, cfg.pdf_dir) + questions = loader.load_questions() + logger.info("Loaded %d questions from %s", len(questions), cfg.data_dir) + + # 4. Validate corpus + found, missing = loader.validate_corpus(questions) + logger.info("PDF corpus: %d found, %d missing", found, len(missing)) + if missing: + preview = missing[:10] + suffix = "..." if len(missing) > 10 else "" + logger.warning("Missing PDFs: %s%s", preview, suffix) + + # 5. Apply limit / seed + if cfg.limit > 0 and cfg.limit < len(questions): + random.seed(cfg.seed) + questions = random.sample(questions, cfg.limit) + logger.info("Sampled %d questions (seed=%d)", len(questions), cfg.seed) + + # 6. Print run config + logger.info( + "Config: mode=%s, eval_mode=%s, llm_judge=%s, concurrent=%d, model=%s", + cfg.mode, + cfg.eval_mode, + cfg.enable_llm_judge, + cfg.max_concurrent, + cfg.llm_model, + ) + + # 7. Run benchmark + t0 = time.time() + results = asyncio.run(run_batch(questions, cfg)) + total_time = time.time() - t0 + + # 8. Compute metrics + metrics = compute_metrics(results) + metrics["total_time_seconds"] = round(total_time, 2) + metrics["num_questions"] = len(questions) + metrics["config"] = { + "mode": cfg.mode, + "eval_mode": cfg.eval_mode, + "model": cfg.llm_model, + "top_k_files": cfg.top_k_files, + } + + # 9. Save results (JSONL) + metrics (JSON) + out_dir = Path(cfg.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + results_path = out_dir / f"results_{ts}.jsonl" + metrics_path = out_dir / f"metrics_{ts}.json" + + with open(results_path, "w", encoding="utf-8") as f: + for r in results: + f.write(json.dumps(r, ensure_ascii=False) + "\n") + + with open(metrics_path, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=2, ensure_ascii=False) + + logger.info("Results saved to %s", results_path) + logger.info("Metrics saved to %s", metrics_path) + + # 10. Print summary + _print_summary(results, metrics, total_time, results_path, metrics_path, log_path) + print(f" Debug log: {debug_log_path.resolve()}") + + # 11. Restore stdout + sys.stdout = tee._terminal + tee.close() + + +def _main_safe() -> None: + """Wrapper that guarantees stdout is restored even on exceptions.""" + try: + main() + except (KeyboardInterrupt, Exception): + # Restore stdout if tee was installed + if hasattr(sys.stdout, "_terminal"): + terminal = sys.stdout._terminal + sys.stdout.close() + sys.stdout = terminal + raise + + +if __name__ == "__main__": + _main_safe() diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py new file mode 100644 index 0000000..7e2f115 --- /dev/null +++ b/benchmarks/financebench/runner.py @@ -0,0 +1,226 @@ +"""Run AgenticSearch on FinanceBench questions. + +Supports two evaluation modes: +- **singleDoc**: each question searches only its target PDF directory. +- **sharedCorpus**: all questions search the full PDF corpus. + +All evaluation (Accuracy + Coverage) is driven by LLM Judge. +""" +from __future__ import annotations + +import asyncio +import json as json_mod +import logging +import time +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +from config import FinanceBenchConfig +from data_loader import FinanceBenchLoader +from evaluate import compute_metrics + +logger = logging.getLogger("financebench.runner") + + +# ------------------------------------------------------------------ +# Single question execution +# ------------------------------------------------------------------ + + +async def run_single( + entry: Dict[str, Any], + loader: FinanceBenchLoader, + searcher: Any, + llm: Any, + cfg: FinanceBenchConfig, + semaphore: asyncio.Semaphore, + judge: Any = None, +) -> Dict[str, Any]: + """Execute one FinanceBench question end-to-end.""" + fb_id = entry.get("financebench_id", "") + question = entry["question"] + gold = entry.get("answer", "") + + async with semaphore: + t0 = time.time() + error: str | None = None + raw_answer = "" + telemetry: dict[str, Any] = {} + + try: + # Determine search paths based on eval mode + if cfg.eval_mode == "singleDoc": + pdf_path = loader.get_pdf_path(entry.get("doc_name", "")) + if pdf_path: + search_paths = [pdf_path] + else: + logger.warning( + "PDF not found for %s, falling back to full corpus", + entry.get("doc_name", ""), + ) + search_paths = [cfg.pdf_dir] + else: + search_paths = [cfg.pdf_dir] + + result = await searcher.search( + query=question, + paths=search_paths, + mode=cfg.mode, + top_k_files=cfg.top_k_files, + max_token_budget=cfg.max_token_budget, + enable_dir_scan=cfg.enable_dir_scan, + return_context=True, + ) + + raw_answer = getattr(result, "answer", "") or str(result) + + # Collect telemetry + read_files = list(getattr(result, "read_file_ids", None) or set()) + telemetry = { + "read_file_ids": read_files, + "total_tokens": getattr(result, "total_llm_tokens", 0), + "loop_count": getattr(result, "loop_count", 0), + "llm_calls": len(getattr(result, "llm_usages", None) or []), + "num_files_read": len(read_files), + } + + except Exception as exc: + error = str(exc) + logger.error("Error on %s: %s", fb_id, error) + + elapsed = time.time() - t0 + + # Delay between requests + if cfg.request_delay > 0: + await asyncio.sleep(cfg.request_delay) + + # --- LLM Judge evaluation (Accuracy + Coverage) --- + judge_correct = False + judge_reasoning = "" + judge_tokens = 0 + has_coverage = False + coverage_reasoning = "" + + if judge is not None: + # Accuracy evaluation + try: + judge_result = await judge.judge( + prediction=raw_answer, + gold_answer=gold, + question=question, + ) + judge_correct = judge_result.get("equivalent", False) + judge_reasoning = judge_result.get("reasoning", "") + judge_tokens += judge_result.get("tokens_used", 0) + except Exception as e: + logger.warning("LLM Judge (accuracy) failed for %s: %s", fb_id, e) + + # Coverage evaluation + try: + coverage_result = await judge.judge_coverage( + prediction=raw_answer, + question=question, + ) + has_coverage = coverage_result.get("has_coverage", False) + coverage_reasoning = coverage_result.get("reasoning", "") + judge_tokens += coverage_result.get("tokens_used", 0) + except Exception as e: + logger.warning("LLM Judge (coverage) failed for %s: %s", fb_id, e) + + return { + "financebench_id": fb_id, + "question": question, + "raw_prediction": raw_answer, + "gold_answer": gold, + "company": entry.get("company", ""), + "doc_name": entry.get("doc_name", ""), + "question_type": entry.get("question_type", ""), + "question_reasoning": entry.get("question_reasoning", ""), + "elapsed": round(elapsed, 2), + "telemetry": telemetry, + "judge_correct": judge_correct, + "judge_reasoning": judge_reasoning, + "coverage": has_coverage, + "coverage_reasoning": coverage_reasoning, + "judge_tokens": judge_tokens, + "error": error, + } + + +# ------------------------------------------------------------------ +# Batch execution +# ------------------------------------------------------------------ + + +async def run_batch( + samples: List[Dict[str, Any]], + cfg: FinanceBenchConfig, +) -> List[Dict[str, Any]]: + """Run all *samples* concurrently and persist results incrementally.""" + from sirchmunk.llm.openai_chat import OpenAIChat + from sirchmunk.search import AgenticSearch + + llm = OpenAIChat( + api_key=cfg.llm_api_key, + base_url=cfg.llm_base_url, + model=cfg.llm_model, + ) + work_path = str(Path(cfg.work_path).resolve()) + searcher = AgenticSearch(llm=llm, work_path=work_path, reuse_knowledge=False, verbose=False) + loader = FinanceBenchLoader(data_dir=cfg.data_dir, pdf_dir=cfg.pdf_dir) + semaphore = asyncio.Semaphore(cfg.max_concurrent) + + # Initialise LLM Judge + judge = None + if cfg.enable_llm_judge: + from judge import FinanceBenchLLMJudge + judge = FinanceBenchLLMJudge(llm=llm) + logger.info("LLM Judge enabled (drives Accuracy + Coverage)") + + # Prepare output directory / file + out_dir = Path(cfg.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + out_path = out_dir / f"financebench_{ts}.jsonl" + + results: list[dict] = [] + completed = 0 + total = len(samples) + + async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: + nonlocal completed + res = await run_single(entry, loader, searcher, llm, cfg, semaphore, judge=judge) + # Incremental save + with open(out_path, "a", encoding="utf-8") as fp: + fp.write(json_mod.dumps(res, ensure_ascii=False) + "\n") + completed += 1 + acc_tag = "\u2713" if res["judge_correct"] else "\u2717" + cov_tag = "cov" if res["coverage"] else "no-cov" + logger.info( + "[%d/%d] %s [acc:%s] [%s] %.1fs", + completed, + total, + res["financebench_id"], + acc_tag, + cov_tag, + res["elapsed"], + ) + return res + + tasks = [asyncio.create_task(_run_and_record(s)) for s in samples] + results = await asyncio.gather(*tasks) + + # Write aggregate metrics + metrics = compute_metrics(list(results)) + metrics_path = out_dir / f"financebench_{ts}_metrics.json" + with open(metrics_path, "w", encoding="utf-8") as fp: + json_mod.dump(metrics, fp, indent=2, ensure_ascii=False) + logger.info("Metrics saved to %s", metrics_path) + logger.info( + "Accuracy=%.2f%% Coverage=%.2f%%", + metrics.get("accuracy", 0), + metrics.get("coverage", 0), + ) + + return list(results) diff --git a/config/env.example b/config/env.example index 8272d03..4b8dcd7 100644 --- a/config/env.example +++ b/config/env.example @@ -126,3 +126,8 @@ SIRCHMUNK_DEBUG=false # Maximum concurrent WebSocket connections (default: 100) SIRCHMUNK_MAX_WS_CONNECTIONS=100 + +# ===== Ablation Experiment Settings ===== +# Pure tree search mode (ablation experiment, default: false) +# When enabled, search relies solely on tree index navigation, skipping rga keyword search +# SIRCHMUNK_PURE_TREE_SEARCH=false diff --git a/requirements/core.txt b/requirements/core.txt index 1848a37..6cff25b 100644 --- a/requirements/core.txt +++ b/requirements/core.txt @@ -5,6 +5,7 @@ openai genson pillow pypdf +pdfminer.six pandas parquet numpy diff --git a/src/sirchmunk/agentic/tools.py b/src/sirchmunk/agentic/tools.py index b13e762..c79cf5a 100644 --- a/src/sirchmunk/agentic/tools.py +++ b/src/sirchmunk/agentic/tools.py @@ -568,3 +568,120 @@ async def execute( ) return result_text, {"query": query, "clusters_found": len(clusters)} + + +# --------------------------------------------------------------------------- +# Tool 5: Tree Navigation (medium cost — LLM-guided tree index navigation) +# --------------------------------------------------------------------------- + +class TreeNavigationTool(BaseTool): + """Navigate a document's compiled tree index to extract targeted evidence. + + Uses an LLM-driven tree navigation strategy: the model selects + the most relevant branches/sections from a hierarchical document + index, then extracts the corresponding page or char-range content. + + This tool bridges the gap between keyword search (which finds + *where* a term appears) and file read (which returns *everything*). + Tree navigation returns the most relevant *sections* of a document + without reading the whole file. + + Requires compile artifacts (tree indices) to be available for the + target files. + """ + + def __init__( + self, + navigate_fn: Any, + available_paths: Optional[set] = None, + max_chars: int = 15_000, + ) -> None: + self._navigate_fn = navigate_fn + self._available_paths = available_paths or set() + self._max_chars = max_chars + + @property + def name(self) -> str: + return "tree_navigate" + + def get_schema(self) -> Dict[str, Any]: + return { + "name": self.name, + "description": ( + "Navigate a document's compiled tree index to extract " + "targeted sections relevant to the query. More precise " + "than file_read — returns only relevant sections instead " + "of the entire file. Works with PDF, DOCX, and other " + "compiled document types. Medium token cost." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": ( + "Absolute path of the document to navigate." + ), + }, + "query": { + "type": "string", + "description": ( + "What information to look for in the document." + ), + }, + }, + "required": ["file_path", "query"], + }, + } + + async def execute( + self, + context: SearchContext, + **kwargs, + ) -> Tuple[str, Dict[str, Any]]: + file_path: str = kwargs.get("file_path", "") + query: str = kwargs.get("query", "") + if not file_path or not query: + return "file_path and query are required.", {} + + if ( + self._available_paths + and file_path not in self._available_paths + ): + return ( + f"No tree index available for {Path(file_path).name}. " + "Use file_read instead." + ), {"file_path": file_path, "indexed": False} + + try: + result = await self._navigate_fn( + file_path, query, max_chars=self._max_chars, + ) + except Exception as exc: + return ( + f"Tree navigation failed: {exc}" + ), {"file_path": file_path, "error": str(exc)} + + if not result: + return ( + f"No relevant sections found in " + f"{Path(file_path).name} for this query." + ), {"file_path": file_path, "chars": 0} + + total_chars = len(result) + approx_tokens = total_chars // 4 + context.add_log( + tool_name=self.name, + tokens=approx_tokens, + metadata={ + "file_path": file_path, + "chars": total_chars, + }, + ) + + header = f"[Tree navigation: {Path(file_path).name}]" + return f"{header}\n{result}", { + "file_path": file_path, + "chars": total_chars, + "tokens": approx_tokens, + } diff --git a/src/sirchmunk/cli/cli.py b/src/sirchmunk/cli/cli.py index 2b330e7..4aec43f 100644 --- a/src/sirchmunk/cli/cli.py +++ b/src/sirchmunk/cli/cli.py @@ -6,6 +6,7 @@ sirchmunk init - Initialize working directory + generate .env sirchmunk serve - Start the API server (backend only) sirchmunk search - Perform a search query + sirchmunk compile - Compile documents into knowledge indices sirchmunk web init - Build WebUI frontend (requires Node.js) sirchmunk web serve - Start API + WebUI (single port) sirchmunk web serve --dev - Start API + Next.js dev server (dual port) @@ -1237,6 +1238,227 @@ def cmd_mcp_version(args: argparse.Namespace) -> int: return 0 +# ------------------------------------------------------------------ +# sirchmunk compile +# ------------------------------------------------------------------ + + +def _configure_compile_threads() -> None: + """Set sensible thread-count defaults for CPU-bound ML workloads. + + Must be called early — before PyTorch, OpenMP, or kreuzberg's Rust + core are imported — so the environment variables are read at library + init time. User-provided overrides are always respected. + """ + cpu_count = os.cpu_count() or 4 + cap = str(max(1, min(cpu_count // 2, 4))) + for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", + "RAYON_NUM_THREADS"): + if var not in os.environ: + os.environ[var] = cap + + +def cmd_compile(args: argparse.Namespace) -> int: + """Compile document collections into structured knowledge indices. + + Builds PageIndex-style tree indices and LLM Wiki-style knowledge + clusters for downstream search acceleration. + + Args: + args: Command-line arguments + + Returns: + Exit code (0 for success, non-zero for failure) + """ + # Cap thread counts BEFORE heavy libraries are imported, so OpenMP/MKL + # read the correct values at init time. User-set vars are respected. + _configure_compile_threads() + + try: + work_path = Path( + getattr(args, "work_path", None) or str(_get_default_work_path()) + ).expanduser().resolve() + os.environ["SIRCHMUNK_WORK_PATH"] = str(work_path) + + env_file = work_path / ".env" + if env_file.exists(): + _load_env_file(env_file) + + paths = args.paths or None + if not paths: + print(" Error: --paths is required for compile.") + print(" Usage: sirchmunk compile --paths /data/docs") + return 1 + + # Status mode + if getattr(args, "status", False): + return asyncio.run(_compile_status(paths, work_path)) + + # Lint mode + if getattr(args, "lint", False): + return asyncio.run(_compile_lint( + work_path, auto_fix=getattr(args, "fix", False), + )) + + # Normal compile + incremental = not getattr(args, "full", False) + return asyncio.run(_compile_run( + paths=paths, + work_path=work_path, + incremental=incremental, + max_files=getattr(args, "max_files", None), + concurrency=getattr(args, "concurrency", 3), + shallow=getattr(args, "shallow", False), + )) + + except KeyboardInterrupt: + print("\n Compile cancelled.") + return 130 + except Exception as e: + logger.error(f"Compile failed: {e}", exc_info=True) + print(f" Compile error: {e}") + return 1 + + +async def _compile_run( + paths: list, + work_path: Path, + incremental: bool = True, + max_files: Optional[int] = None, + concurrency: int = 3, + shallow: bool = False, +) -> int: + """Execute compile using AgenticSearch.""" + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + + llm_api_key = os.getenv("LLM_API_KEY", "") + if not llm_api_key: + print(" LLM_API_KEY is not set.") + print(" Configure it in ~/.sirchmunk/.env or set the environment variable.") + return 1 + + llm = OpenAIChat( + base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + api_key=llm_api_key, + model=os.getenv("LLM_MODEL_NAME", "gpt-5.2"), + ) + + searcher = AgenticSearch(llm=llm, work_path=str(work_path)) + + print("=" * 60) + print(" Sirchmunk Knowledge Compile") + print("=" * 60) + print() + print(f" Paths: {', '.join(paths)}") + print(f" Incremental: {incremental}") + if shallow: + print(" Mode: shallow (tree indexing skipped)") + if max_files: + print(f" Max files: {max_files} (importance sampling)") + print() + + report = await searcher.compile( + paths=paths, + incremental=incremental, + shallow=shallow, + max_files=max_files, + concurrency=concurrency, + ) + + print() + print("=" * 60) + print(" Compile Report") + print("=" * 60) + print() + print(f" Total files: {report.get('total_files', 0)}") + print(f" Files added: {report.get('files_added', 0)}") + print(f" Files modified: {report.get('files_modified', 0)}") + print(f" Files skipped: {report.get('files_skipped', 0)}") + if report.get("files_sampled"): + print(f" Files sampled: {report['files_sampled']}") + print(f" Trees built: {report.get('trees_built', 0)}") + print(f" Clusters created: {report.get('clusters_created', 0)}") + print(f" Clusters merged: {report.get('clusters_merged', 0)}") + print(f" Cross-refs: {report.get('cross_refs_built', 0)}") + print(f" Elapsed: {report.get('elapsed_seconds', 0):.1f}s") + if report.get("errors"): + print(f" Errors: {len(report['errors'])}") + for err in report["errors"][:5]: + print(f" - {err}") + print() + + return 0 + + +async def _compile_status(paths: list, work_path: Path) -> int: + """Show compile status.""" + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + + llm = OpenAIChat( + base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + api_key=os.getenv("LLM_API_KEY", ""), + model=os.getenv("LLM_MODEL_NAME", "gpt-5.2"), + ) + + searcher = AgenticSearch(llm=llm, work_path=str(work_path)) + status = await searcher.compile_status(paths=paths) + + print("=" * 60) + print(" Compile Status") + print("=" * 60) + print() + print(f" Compiled files: {status.get('total_compiled_files', 0)}") + print(f" Tree indices: {status.get('total_trees', 0)}") + print(f" Clusters: {status.get('total_clusters', 0)}") + print(f" Last compile: {status.get('last_compile_at', 'Never')}") + print() + + return 0 + + +async def _compile_lint(work_path: Path, auto_fix: bool = False) -> int: + """Run knowledge lint checks.""" + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + + llm = OpenAIChat( + base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + api_key=os.getenv("LLM_API_KEY", ""), + model=os.getenv("LLM_MODEL_NAME", "gpt-5.2"), + ) + + searcher = AgenticSearch(llm=llm, work_path=str(work_path)) + report = await searcher.compile_lint(auto_fix=auto_fix) + + print("=" * 60) + print(" Knowledge Lint Report") + print("=" * 60) + print() + print(f" Clusters checked: {report.get('total_clusters_checked', 0)}") + print(f" Trees checked: {report.get('total_trees_checked', 0)}") + print(f" Errors: {report.get('errors', 0)}") + print(f" Warnings: {report.get('warnings', 0)}") + if auto_fix: + print(f" Auto-fixes: {report.get('auto_fixes_applied', 0)}") + print() + + issues = report.get("issues", []) + if issues: + for issue in issues[:20]: + severity = issue.get("severity", "info").upper() + msg = issue.get("message", "") + cid = issue.get("cluster_id", "") + fixed = " [FIXED]" if issue.get("auto_fixed") else "" + print(f" [{severity}] {msg} {f'(cluster={cid})' if cid else ''}{fixed}") + if len(issues) > 20: + print(f" ... and {len(issues) - 20} more") + print() + + return 0 + + # ------------------------------------------------------------------ # sirchmunk upload # ------------------------------------------------------------------ @@ -1453,6 +1675,54 @@ def create_parser() -> argparse.ArgumentParser: ) search_parser.set_defaults(func=cmd_search) + # === compile command === + compile_parser = subparsers.add_parser( + "compile", + help="Compile document collections into knowledge indices", + description=( + "Compile documents into structured knowledge indices (tree + clusters). " + "Optional step after 'sirchmunk init'." + ), + ) + compile_parser.add_argument( + "--paths", nargs="+", required=True, + help="Directories or files to compile", + ) + compile_parser.add_argument( + "--full", action="store_true", default=False, + help="Force full recompile (ignore incremental cache)", + ) + compile_parser.add_argument( + "--max-files", type=int, default=None, + help="Max files to process (triggers importance sampling for large sets)", + ) + compile_parser.add_argument( + "--concurrency", type=int, default=3, + help="Max parallel file compilations (default: 3)", + ) + compile_parser.add_argument( + "--shallow", action="store_true", default=False, + help="Skip tree indexing — use direct LLM summarisation only (faster)", + ) + compile_parser.add_argument( + "--status", action="store_true", default=False, + help="Show compile status instead of running compile", + ) + compile_parser.add_argument( + "--lint", action="store_true", default=False, + help="Run knowledge health checks", + ) + compile_parser.add_argument( + "--fix", action="store_true", default=False, + help="Auto-fix lint issues (use with --lint)", + ) + compile_parser.add_argument( + "--work-path", + default=None, + help="Working directory (default: ~/.sirchmunk)", + ) + compile_parser.set_defaults(func=cmd_compile) + # === web command group === web_parser = subparsers.add_parser( "web", diff --git a/src/sirchmunk/learnings/README.md b/src/sirchmunk/learnings/README.md new file mode 100644 index 0000000..92bc22b --- /dev/null +++ b/src/sirchmunk/learnings/README.md @@ -0,0 +1,248 @@ +# Sirchmunk Learnings Module + +The `sirchmunk/learnings` module implements **knowledge compilation and continuous learning** capabilities. It houses the core logic for transforming raw document collections into structured, searchable knowledge networks. + +## Architecture Overview + +``` +learnings/ +├── __init__.py # Public API exports +├── knowledge_base.py # Runtime knowledge builder (search-time) +├── evidence_processor.py # Monte Carlo evidence sampling +├── compiler.py # Offline knowledge compiler (compile-time) +├── tree_indexer.py # PageIndex-style document tree indexer +├── lint.py # Knowledge network health checks +└── README.md # This file +``` + +### Design Philosophy + +The module fuses insights from three frameworks: + +1. **PageIndex** (VectifyAI) — Hierarchical tree indexing replaces brute-force vector search with LLM reasoning-based navigation. The key insight: *similarity ≠ relevance*. + +2. **LLM Wiki** (Karpathy) — Documents are not merely "indexed" but "compiled" into an interlinked knowledge network that compounds over time. Knowledge clusters grow richer with each compile cycle. + +3. **NotebookLM** (Google) — Strict source grounding ensures every claim traces back to original evidence. The `EvidenceUnit` system provides full provenance. + +### Compile vs. Search + +| Aspect | Compile (offline) | Search (runtime) | +|--------|-------------------|-------------------| +| **When** | `sirchmunk compile` | `sirchmunk search` | +| **Speed** | Minutes (batch) | Seconds (interactive) | +| **Purpose** | Build indices + knowledge | Answer queries | +| **Module** | `compiler.py` (uses `tree_indexer.py`) | `knowledge_base.py`, `evidence_processor.py` | +| **Required** | Optional | Always available | + +Compile products are automatically leveraged by search when present, but search functions independently without them. + +### How Search Consumes Compile Products + +``` +Compile products Search consumption path +───────────────── ────────────────────────────────────────────── +KnowledgeCluster ─┬─ FAST + DEEP Phase 0: embedding similarity + .content │ reuse (instant short-circuit, no LLM cost) + .embedding │ → enriched with evidence snippets + .evidences[].file_or_url │ + ├─ DEEP Phase 1: _probe_knowledge_cache() + │ fuzzy text search → file path discovery + │ +WeakSemanticEdge ├─ DEEP Phase 1: one-hop graph expansion + .related_clusters │ follows edges to gather neighbour files + │ +DocumentTree (.json) └─ DEEP Phase 3: tree-navigated evidence + via tree_indexer _build_cluster() → knowledge_base.build() + → _extract_evidence_for_file(tree_indexer) + → narrows doc to relevant sections before + Monte Carlo sampling +``` + +| Compile product | FAST | DEEP | +|-----------------|------|------| +| Cluster embedding reuse | Yes | Yes | +| Evidence snippets in reused content | Yes | Yes | +| Fuzzy cluster → file path hints | — | Yes | +| Graph edge expansion (neighbours) | — | Yes | +| Tree-navigated evidence extraction | — | Yes | + +--- + +## Components + +### DocumentTreeIndexer (`tree_indexer.py`) + +Builds hierarchical JSON tree indices for structured long documents. + +**Key concepts:** +- Only triggers for documents ≥ 50KB in eligible formats (PDF, DOCX, MD, HTML, etc.) +- LLM analyzes document structure recursively (up to 4 levels deep) +- Each node stores: title, summary, character range +- Query-time navigation: LLM selects relevant branches instead of scanning everything + +**Data structures:** +- `TreeNode` — Single node with `node_id`, `title`, `summary`, `char_range`, `children` +- `DocumentTree` — Complete tree for a document, JSON-serializable, cached by file hash + +**Usage:** +```python +indexer = DocumentTreeIndexer(llm=llm, cache_dir=cache_path) + +# Build (async, LLM-powered) +tree = await indexer.build_tree(file_path, content, max_depth=4) + +# Navigate (async, LLM-powered branch selection) +leaves = await indexer.navigate(tree, query="How does X work?") +for leaf in leaves: + relevant_text = content[leaf.char_range[0]:leaf.char_range[1]] + +# Cache check (sync) +if indexer.has_tree(file_path): + tree = indexer.load_tree(file_path) +``` + +### KnowledgeCompiler (`compiler.py`) + +Orchestrates the unified compile pipeline. + +**Four-phase pipeline:** +1. **File Discovery & Change Detection** — Scans paths, compares with manifest for incremental processing +2. **Per-File Compile** — Unified pipeline per file: tree-if-eligible → summary → topics → rich evidence +3. **Knowledge Aggregation** — Merges into existing clusters or creates new ones (three-tier similarity) +4. **Cross-Reference Building** — Creates `WeakSemanticEdge` links between related clusters + +**Unified single-file pipeline:** +For each file, the compiler runs a single pipeline instead of separate "tree" and "wiki" modes: +- If the file is ≥ 50KB and in an eligible format, a tree is built first. The root node's summary is synthesized from children's section summaries via LLM, and `EvidenceUnit` snippets + `tree_path` are populated directly from tree leaves. +- If the file is small or `shallow=True`, a direct LLM summary is generated instead. +- In both cases, topics are extracted and a `KnowledgeCluster` is created/merged. + +**Three-tier similarity strategy:** +| Similarity | Action | +|-----------|--------| +| ≥ 0.80 | Merge into existing cluster, re-compute embedding | +| 0.50 – 0.79 | Create new cluster + build `embed_sim` weak edges | +| < 0.50 | Create standalone cluster | + +**Importance probability sampling** (`ImportanceSampler`): +For large datasets, select a representative subset using weighted random sampling: +- File size (log-scaled): larger files contain more information +- Novelty: uncompiled files get 4× weight over already-compiled ones +- Extension diversity: structured formats (PDF, DOCX) get 1.5× boost + +**Key data structures:** +- `CompileManifest` — Tracks file hashes and cluster associations for incremental compile +- `FileManifestEntry` — Per-file state (hash, compile timestamp, tree flag, cluster IDs) +- `CompileReport` — Statistics from a compile run +- `CompileStatus` — Quick status snapshot + +### KnowledgeLint (`lint.py`) + +Health checks for the knowledge network (inspired by LLM Wiki's Lint operation). + +**Checks performed:** +- **Empty clusters** — Clusters with minimal or no content +- **Stale evidence** — Evidence pointing to files that no longer exist +- **Orphan clusters** — Clusters with no evidence and no queries +- **Isolated clusters** — Clusters with no cross-references +- **Orphan trees** — Tree cache files without matching manifest entries +- **Stale manifest** — Manifest entries pointing to deleted files + +**Auto-fix capabilities:** +- Deprecate clusters where all evidence sources are gone +- Remove orphan tree cache files + +### KnowledgeBase (`knowledge_base.py`) + +Runtime knowledge builder used during search operations. + +**Tree-aware evidence extraction:** +When a tree index exists for a file, `_extract_evidence_for_file()` navigates to relevant sections first, then runs Monte Carlo sampling within those narrowed regions. This dramatically improves precision for large documents. + +### MonteCarloEvidenceSampling (`evidence_processor.py`) + +Statistical sampling for finding relevant regions in documents. Used both at compile-time and search-time. + +--- + +## CLI Interface + +```bash +# Compile documents (optional, after sirchmunk init) +sirchmunk compile --paths /data/docs /data/reports + +# Incremental compile (default, skips unchanged files) +sirchmunk compile --paths /data/docs + +# Full recompile +sirchmunk compile --paths /data/docs --full + +# Importance sampling for large datasets +sirchmunk compile --paths /data/docs --max-files 100 + +# Shallow mode: skip tree indexing, use direct LLM summarisation +sirchmunk compile --paths /data/docs --shallow + +# Check compile status +sirchmunk compile --paths /data/docs --status + +# Run health checks +sirchmunk compile --paths /data/docs --lint +sirchmunk compile --paths /data/docs --lint --fix +``` + +## Python SDK + +```python +from sirchmunk.search import AgenticSearch + +searcher = AgenticSearch(work_path="~/.sirchmunk") + +# Compile +report = await searcher.compile( + paths=["/data/docs"], + incremental=True, + shallow=False, # set True to skip tree indexing + max_files=100, # importance sampling + concurrency=3, +) + +# Status +status = await searcher.compile_status(paths=["/data/docs"]) + +# Lint +lint_report = await searcher.compile_lint(auto_fix=True) + +# Search (automatically uses compile products when available) +result = await searcher.search("query", paths=["/data/docs"]) +``` + +--- + +## Cache Layout + +``` +{work_path}/.cache/ +├── compile/ +│ ├── manifest.json # Compile manifest (incremental state) +│ └── trees/ +│ ├── {file_hash_1}.json # Tree index for document 1 +│ └── {file_hash_2}.json # Tree index for document 2 +└── knowledge/ + └── knowledge_clusters.parquet # Persistent cluster storage (DuckDB + Parquet) +``` + +## Schema Extensions + +The compile feature extends existing schemas: + +- **`EvidenceUnit`** — Added `tree_path` (node IDs from tree navigation) and `page_range` (character offsets) +- **`KnowledgeCluster`** — Added `merge_count` (tracks compile-time merge frequency for lifecycle promotion: ≥ 3 merges → `STABLE`) + +## Design Principles + +- **SOLID compliance**: Each class has a single responsibility; dependencies are injected via constructor +- **Optional by design**: Compile never breaks existing search functionality +- **Incremental**: Only processes changed files; manifest tracks state across runs +- **Production-ready**: Bounded concurrency, error isolation per file, graceful schema migration diff --git a/src/sirchmunk/learnings/__init__.py b/src/sirchmunk/learnings/__init__.py index 0829846..bc14211 100644 --- a/src/sirchmunk/learnings/__init__.py +++ b/src/sirchmunk/learnings/__init__.py @@ -1 +1,28 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. \ No newline at end of file +# Copyright (c) ModelScope Contributors. All rights reserved. + +from sirchmunk.learnings.compiler import ( + CompileManifest, + CompileReport, + CompileStatus, + ImportanceSampler, + KnowledgeCompiler, +) +from sirchmunk.learnings.lint import KnowledgeLint, LintReport +from sirchmunk.learnings.tree_indexer import ( + DocumentTree, + DocumentTreeIndexer, + TreeNode, +) + +__all__ = [ + "CompileManifest", + "CompileReport", + "CompileStatus", + "DocumentTree", + "DocumentTreeIndexer", + "ImportanceSampler", + "KnowledgeCompiler", + "KnowledgeLint", + "LintReport", + "TreeNode", +] \ No newline at end of file diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py new file mode 100644 index 0000000..e812bef --- /dev/null +++ b/src/sirchmunk/learnings/compiler.py @@ -0,0 +1,2743 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Knowledge compiler — orchestrates offline compile of document collections. + +Fuses PageIndex (tree indexing) and LLM Wiki (knowledge compilation network) +into a single compile pipeline that produces structured tree indices and +knowledge clusters for downstream search acceleration. +""" + +import asyncio +import bisect +import ctypes +import gc +import json +import math +import os +import platform +import random +import re +import hashlib +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from sirchmunk.learnings.tree_indexer import ( + DocumentTree, + DocumentTreeIndexer, +) +from sirchmunk.llm.openai_chat import OpenAIChat +from sirchmunk.schema.knowledge import ( + AbstractionLevel, + EvidenceUnit, + KnowledgeCluster, + Lifecycle, + WeakSemanticEdge, +) +from sirchmunk.storage.knowledge_storage import KnowledgeStorage +from sirchmunk.utils import LogCallback, create_logger +from sirchmunk.utils.document_extractor import DocumentExtractor +from sirchmunk.utils.file_utils import get_fast_hash + +# Concurrency cap for LLM-heavy file processing +_DEFAULT_CONCURRENCY = 3 + +# Similarity threshold for merging into existing clusters during compile +_MERGE_SIMILARITY_THRESHOLD = 0.75 + +# Max chars for manifest-persisted document summary (used in Phase 2 & catalog) +_MANIFEST_SUMMARY_MAX_LEN = 500 + +# Preview window for direct LLM summarisation (no tree), ~4K tokens +_SUMMARY_PREVIEW_CHARS = 16_000 + +# Multi-section sampling for large documents without a tree index +_SUMMARY_SAMPLE_SECTIONS = 3 # Number of sections to sample for large docs +_SUMMARY_SAMPLE_SECTION_CHARS = 5_000 # Chars per sampled section + +# Targeted table extraction: max chars per table region +_TARGETED_TABLE_MAX_CHARS = 5000 + +# Targeted table extraction: only process nodes spanning <= N pages +_TABLE_PAGE_SPAN_LIMIT = 5 + +# Numeric density threshold – fraction of numeric/symbol chars ($, %, digits, +# parenthesised numbers) relative to total non-whitespace chars. Pages below +# this threshold are skipped during targeted extraction. +_TABLE_NUMERIC_DENSITY_THRESHOLD = 0.15 + +# Selective force-OCR: max pages to re-extract with forced OCR per document +_FORCE_OCR_MAX_PAGES = 30 + +# Incremental manifest flush: persist manifest every N completed files +# to survive interrupted compiles without excessive I/O overhead. +_MANIFEST_FLUSH_INTERVAL = 10 + +# Page-level extraction: max pages to load into memory per batch. +# Prevents loading all 200-400 pages of a large PDF at once. +_PAGE_SCAN_BATCH_SIZE = 50 + +# How often to run gc.collect() inside the compile loop (every N files). +_GC_INTERVAL = 5 + + +def _force_gc() -> None: + """Aggressively reclaim Python-managed memory and nudge the C allocator.""" + gc.collect() + if platform.system() == "Linux": + try: + ctypes.CDLL("libc.so.6").malloc_trim(0) + except (OSError, AttributeError): + pass + + +# Shared numeric-token regex for table detection heuristics. +# Matches: $1,234 (1,234) 12.5% 3.14e-5 1,000 +_NUM_TOKEN_RE = re.compile( + r"(?:" + r"[\$€£¥]\s*[\d,.]+|" + r"\([\d,.]+\)|" + r"[\d,.]+%|" + r"[\d]+\.[\d]+(?:[eE][+-]?\d+)?|" + r"[\d,]{2,}" + r")" +) + +# A single line with >= this many numeric tokens is treated as a dense +# table row (or multiple rows concatenated), enabling detection even when +# pypdf flattens the entire page to one or two lines. +_DENSE_LINE_MIN_TOKENS = 15 + +# --------------------------------------------------------------------------- +# Heading normalisation: candidate extraction patterns +# --------------------------------------------------------------------------- +# kreuzberg sometimes renders section titles as ``**bold text**`` or bare +# short standalone lines instead of ``## heading``. The tree indexer can +# only split on markdown headings, so these "invisible" titles get absorbed +# into parent nodes. +# +# We extract *candidates* via lightweight regexes and let the LLM classify +# which ones are genuine section headings (language/domain-agnostic). + +_BOLD_LINE_RE = re.compile( + r"^\*\*((?:(?!\*\*).)+)\*\*\s*$", + re.MULTILINE, +) + +_STANDALONE_LINE_RE = re.compile( + r"(?:^|\n\n)([^\n]{5,100})\n\n", +) + +_HEADING_CANDIDATE_CAP = 40 + +# Excel table-level adaptive sampling constants +_XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets +_XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet +_XLSX_MAX_ROWS_PER_SHEET = 50 # Maximum sampled rows per sheet +_XLSX_MAX_SHEETS = 10 # Maximum number of sheets to process +_XLSX_MAX_COLS_DISPLAY = 20 # Maximum columns to display per sheet + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class FileManifestEntry: + """State of a single file in the compile manifest.""" + + file_hash: str + compiled_at: str + has_tree: bool + cluster_ids: List[str] + size_bytes: int + summary: str = "" # 新增:存储编译期生成的文档摘要 + has_explicit_toc: bool = False # Whether a native TOC was extracted from the file + tree_node_count: int = 0 # Number of nodes in the tree index (quality metric) + has_xlsx_digest: bool = False # Whether a pre-compiled Excel evidence digest exists + has_table_digest: bool = False # Whether PDF tables were extracted and stored + table_count: int = 0 # Number of tables in this file + + def to_dict(self) -> Dict[str, Any]: + return { + "file_hash": self.file_hash, + "compiled_at": self.compiled_at, + "has_tree": self.has_tree, + "cluster_ids": self.cluster_ids, + "size_bytes": self.size_bytes, + "summary": self.summary, + "has_explicit_toc": self.has_explicit_toc, + "tree_node_count": self.tree_node_count, + "has_xlsx_digest": self.has_xlsx_digest, + "has_table_digest": self.has_table_digest, + "table_count": self.table_count, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": + return cls( + file_hash=data["file_hash"], + compiled_at=data["compiled_at"], + has_tree=data.get("has_tree", False), + cluster_ids=data.get("cluster_ids", []), + size_bytes=data.get("size_bytes", 0), + summary=data.get("summary", ""), + has_explicit_toc=data.get("has_explicit_toc", False), + tree_node_count=data.get("tree_node_count", 0), + has_xlsx_digest=data.get("has_xlsx_digest", False), + has_table_digest=data.get("has_table_digest", False), + table_count=data.get("table_count", 0), + ) + + +@dataclass +class CompileManifest: + """Tracks compiled file states for incremental processing.""" + + version: str = "1.0" + last_compile_at: Optional[str] = None + files: Dict[str, FileManifestEntry] = field(default_factory=dict) + + def to_json(self) -> str: + return json.dumps({ + "version": self.version, + "last_compile_at": self.last_compile_at, + "files": {k: v.to_dict() for k, v in self.files.items()}, + }, ensure_ascii=False, indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "CompileManifest": + data = json.loads(json_str) + files = { + k: FileManifestEntry.from_dict(v) + for k, v in data.get("files", {}).items() + } + return cls( + version=data.get("version", "1.0"), + last_compile_at=data.get("last_compile_at"), + files=files, + ) + + +@dataclass +class FileEntry: + """Discovered file pending compilation.""" + + path: str + size_bytes: int + file_hash: str + + +@dataclass +class ChangeSet: + """Delta between discovered files and the manifest.""" + + added: List[FileEntry] = field(default_factory=list) + modified: List[FileEntry] = field(default_factory=list) + deleted: List[str] = field(default_factory=list) + unchanged: List[str] = field(default_factory=list) + + +@dataclass +class FileCompileResult: + """Result of compiling a single file.""" + + path: str + tree: Optional[DocumentTree] = None + summary: str = "" + topics: List[str] = field(default_factory=list) + evidence: Optional[EvidenceUnit] = None + cluster_ids: List[str] = field(default_factory=list) + error: Optional[str] = None + has_explicit_toc: bool = False # Whether TOC was extracted from native structure + tree_node_count: int = 0 # Number of nodes in the tree index + has_xlsx_digest: bool = False # Whether a pre-compiled Excel evidence digest exists + has_table_digest: bool = False # Whether a pre-compiled table digest exists + table_count: int = 0 # Number of tables extracted + + +@dataclass +class CompileReport: + """Summary report of a compile run.""" + + total_files: int = 0 + files_added: int = 0 + files_modified: int = 0 + files_skipped: int = 0 + files_deleted: int = 0 + files_sampled: int = 0 + trees_built: int = 0 + clusters_created: int = 0 + clusters_merged: int = 0 + cross_refs_built: int = 0 + errors: List[str] = field(default_factory=list) + elapsed_seconds: float = 0.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "total_files": self.total_files, + "files_added": self.files_added, + "files_modified": self.files_modified, + "files_skipped": self.files_skipped, + "files_deleted": self.files_deleted, + "files_sampled": self.files_sampled, + "trees_built": self.trees_built, + "clusters_created": self.clusters_created, + "clusters_merged": self.clusters_merged, + "cross_refs_built": self.cross_refs_built, + "errors": self.errors, + "elapsed_seconds": round(self.elapsed_seconds, 2), + } + + +@dataclass +class CompileStatus: + """Status snapshot of the compile state.""" + + total_compiled_files: int = 0 + total_clusters: int = 0 + total_trees: int = 0 + last_compile_at: Optional[str] = None + manifest_path: str = "" + + +# --------------------------------------------------------------------------- +# Importance probability sampler +# --------------------------------------------------------------------------- + +class ImportanceSampler: + """Select a representative subset of files using importance-based probability. + + Sampling strategy for large datasets: + - Larger files get higher probability (they contain more information). + - Uncompiled (new) files are prioritised over previously compiled ones. + - Files with rare extensions get a mild boost (diversity signal). + - The final probability is proportional to a composite importance score. + """ + + def __init__(self, max_files: int, seed: Optional[int] = None): + self._max_files = max_files + self._rng = random.Random(seed) + + def sample(self, files: List[FileEntry], manifest: CompileManifest) -> List[FileEntry]: + """Return up to *max_files* entries sampled by importance.""" + if len(files) <= self._max_files: + return files + + scores = [self._score(f, manifest) for f in files] + total = sum(scores) or 1.0 + probs = [s / total for s in scores] + + selected_indices = set() + attempts = 0 + while len(selected_indices) < self._max_files and attempts < len(files) * 3: + idx = self._weighted_choice(probs) + selected_indices.add(idx) + attempts += 1 + + return [files[i] for i in sorted(selected_indices)] + + def _score(self, entry: FileEntry, manifest: CompileManifest) -> float: + """Compute composite importance score.""" + # Size factor: log-scaled, bounded + size_score = math.log2(max(entry.size_bytes, 1024)) / 20.0 + + # Novelty factor: new files are more important + novelty = 2.0 if entry.path not in manifest.files else 0.5 + + # Extension diversity: rare extensions get a mild boost + ext = Path(entry.path).suffix.lower() + diversity = 1.5 if ext in {".pdf", ".docx", ".doc", ".tex"} else 1.0 + + return size_score * novelty * diversity + + def _weighted_choice(self, probs: List[float]) -> int: + r = self._rng.random() + cumulative = 0.0 + for i, p in enumerate(probs): + cumulative += p + if r <= cumulative: + return i + return len(probs) - 1 + + +# --------------------------------------------------------------------------- +# Compiler +# --------------------------------------------------------------------------- + +class KnowledgeCompiler: + """Orchestrate compile pipeline: file discovery -> tree indexing -> knowledge aggregation.""" + + # File extensions eligible for compilation + _ELIGIBLE_EXTENSIONS = { + ".pdf", ".docx", ".doc", ".md", ".markdown", ".html", ".htm", + ".rst", ".tex", ".txt", ".pptx", ".xlsx", + } + + def __init__( + self, + llm: OpenAIChat, + embedding_client: Optional[Any], + knowledge_storage: KnowledgeStorage, + tree_indexer: DocumentTreeIndexer, + work_path: Union[str, Path], + log_callback: LogCallback = None, + ): + self._llm = llm + self._embedding = embedding_client + self._storage = knowledge_storage + self._tree_indexer = tree_indexer + self._work_path = Path(work_path).expanduser().resolve() + self._log = create_logger(log_callback=log_callback) + + self._compile_dir = self._work_path / ".cache" / "compile" + self._compile_dir.mkdir(parents=True, exist_ok=True) + self._manifest_path = self._compile_dir / "manifest.json" + + # ------------------------------------------------------------------ # + # Resource management # + # ------------------------------------------------------------------ # + + @staticmethod + def _configure_thread_limits() -> None: + """Cap PyTorch thread count to reduce per-thread memory allocation. + + Environment variables (OMP_NUM_THREADS, etc.) are set in the CLI + entry point before libraries are imported. This method handles the + PyTorch-specific runtime API that works retroactively. + """ + cpu_count = os.cpu_count() or 4 + cap = max(1, min(cpu_count // 2, 4)) + try: + import torch + torch.set_num_threads(cap) + torch.set_num_interop_threads(max(1, cap // 2)) + except (ImportError, RuntimeError): + pass + + # ------------------------------------------------------------------ # + # Public API # + # ------------------------------------------------------------------ # + + async def compile( + self, + paths: List[str], + *, + incremental: bool = True, + shallow: bool = False, + max_files: Optional[int] = None, + concurrency: int = _DEFAULT_CONCURRENCY, + ) -> CompileReport: + """Execute the unified knowledge compile pipeline. + + Args: + paths: Directories or files to compile. + incremental: Skip unchanged files. + shallow: Skip tree building even for eligible files — use direct + LLM summarisation only (faster, lower quality). + max_files: Cap on files to process (triggers importance sampling). + concurrency: Max parallel file compilations. + """ + import time + + self._configure_thread_limits() + + t0 = time.monotonic() + report = CompileReport() + + # Phase 1: discover and diff + await self._log.info("[Compile] Phase 1: File discovery & change detection") + manifest = self._load_manifest() + discovered = await self._discover_files(paths) + report.total_files = len(discovered) + await self._log.info(f"[Compile] Discovered {len(discovered)} eligible files") + + if incremental: + changes = self._detect_changes(discovered, manifest) + to_compile = changes.added + changes.modified + report.files_skipped = len(changes.unchanged) + report.files_deleted = len(changes.deleted) + + stale_paths = changes.deleted + [e.path for e in changes.modified] + if stale_paths: + await self._purge_stale_artifacts(stale_paths, manifest) + else: + to_compile = discovered + report.files_skipped = 0 + + report.files_added = len([f for f in to_compile if f.path not in manifest.files]) + report.files_modified = len(to_compile) - report.files_added + + # Phase 1.5: importance sampling for large datasets + if max_files and len(to_compile) > max_files: + await self._log.info( + f"[Compile] Applying importance sampling: {len(to_compile)} -> {max_files} files" + ) + sampler = ImportanceSampler(max_files=max_files) + to_compile = sampler.sample(to_compile, manifest) + report.files_sampled = len(to_compile) + + if not to_compile: + await self._log.info("[Compile] No files to compile (all up-to-date)") + report.elapsed_seconds = time.monotonic() - t0 + return report + + await self._log.info( + f"[Compile] Phase 2: Processing {len(to_compile)} files " + f"(concurrency={concurrency})" + ) + + # Phase 2 + 3 (fused): compile files, aggregate inline, release heavy objects + # Fusing Phase 3 into the completion loop avoids retaining all + # DocumentTree / EvidenceUnit objects until the end of the pipeline. + semaphore = asyncio.Semaphore(concurrency) + _xref_pairs: List[Tuple[str, List[str]]] = [] # lightweight (path, cluster_ids) for Phase 4 + _files_since_flush = 0 + _files_since_gc = 0 + + async def _bounded(entry: FileEntry) -> FileCompileResult: + async with semaphore: + return await self._compile_single_file(entry, shallow=shallow) + + tasks = [_bounded(f) for f in to_compile] + for coro in asyncio.as_completed(tasks): + result = await coro + if result.error: + report.errors.append(f"{result.path}: {result.error}") + else: + if result.tree: + report.trees_built += 1 + manifest.files[result.path] = FileManifestEntry( + file_hash=get_fast_hash(result.path) or "", + compiled_at=datetime.now(timezone.utc).isoformat(), + has_tree=result.tree is not None, + cluster_ids=result.cluster_ids, + size_bytes=Path(result.path).stat().st_size if Path(result.path).exists() else 0, + summary=result.summary[:_MANIFEST_SUMMARY_MAX_LEN] if result.summary else "", + has_explicit_toc=result.has_explicit_toc, + tree_node_count=result.tree_node_count, + has_xlsx_digest=result.has_xlsx_digest, + has_table_digest=result.has_table_digest, + table_count=result.table_count, + ) + + # Phase 3 inline: aggregate while the result is still alive + if not result.error and result.summary: + created, merged = await self._aggregate_to_knowledge_network(result) + report.clusters_created += created + report.clusters_merged += merged + + # Retain only lightweight cross-ref data, then drop the heavy result + _xref_pairs.append((result.path, list(result.cluster_ids))) + del result + + # Incremental manifest flush to survive interrupted compiles + _files_since_flush += 1 + if _files_since_flush >= _MANIFEST_FLUSH_INTERVAL: + manifest.last_compile_at = datetime.now(timezone.utc).isoformat() + self._save_manifest(manifest) + _files_since_flush = 0 + + _files_since_gc += 1 + if _files_since_gc >= _GC_INTERVAL: + _force_gc() + _files_since_gc = 0 + + # Phase 2 checkpoint: persist manifest before cross-references + manifest.last_compile_at = datetime.now(timezone.utc).isoformat() + self._save_manifest(manifest) + + # Phase 4: cross-references (uses only lightweight path+cluster_ids pairs) + await self._log.info("[Compile] Phase 4: Building cross-references") + report.cross_refs_built = await self._build_cross_references_from_pairs( + _xref_pairs, manifest, + ) + + # Phase 5: persist final manifest + derived indices + # Catalog and summary index are rebuilt from the manifest, so even + # partial compiles produce usable search-time metadata. + manifest.last_compile_at = datetime.now(timezone.utc).isoformat() + self._save_manifest(manifest) + self._storage.force_sync() + + self._build_document_catalog(manifest) + + await self._build_summary_index(manifest) + + report.elapsed_seconds = time.monotonic() - t0 + await self._log.info( + f"[Compile] Done in {report.elapsed_seconds:.1f}s — " + f"trees={report.trees_built}, created={report.clusters_created}, " + f"merged={report.clusters_merged}, errors={len(report.errors)}" + ) + return report + + async def get_status(self, paths: List[str]) -> CompileStatus: + """Return current compile status for the given paths.""" + manifest = self._load_manifest() + path_set = {str(Path(p).resolve()) for p in paths} + + compiled_count = 0 + tree_count = 0 + cluster_ids: Set[str] = set() + for fp, entry in manifest.files.items(): + for p in path_set: + if fp.startswith(p): + compiled_count += 1 + if entry.has_tree: + tree_count += 1 + cluster_ids.update(entry.cluster_ids) + break + + return CompileStatus( + total_compiled_files=compiled_count, + total_clusters=len(cluster_ids), + total_trees=tree_count, + last_compile_at=manifest.last_compile_at, + manifest_path=str(self._manifest_path), + ) + + # ------------------------------------------------------------------ # + # File discovery and change detection # + # ------------------------------------------------------------------ # + + async def _discover_files(self, paths: List[str]) -> List[FileEntry]: + """Walk paths and return all compilation-eligible files.""" + entries: List[FileEntry] = [] + seen: Set[str] = set() + + for base in paths: + base_path = Path(base).expanduser().resolve() + if base_path.is_file(): + candidates = [base_path] + elif base_path.is_dir(): + candidates = sorted(base_path.rglob("*")) + else: + continue + + for fp in candidates: + if not fp.is_file(): + continue + if fp.suffix.lower() not in self._ELIGIBLE_EXTENSIONS: + continue + abs_path = str(fp.resolve()) + if abs_path in seen: + continue + seen.add(abs_path) + fh = get_fast_hash(abs_path) + if fh is None: + continue + entries.append(FileEntry( + path=abs_path, + size_bytes=fp.stat().st_size, + file_hash=fh, + )) + + return entries + + def _detect_changes( + self, discovered: List[FileEntry], manifest: CompileManifest, + ) -> ChangeSet: + """Compare discovered files against the manifest for incremental compile.""" + changes = ChangeSet() + current_paths = {f.path for f in discovered} + + for entry in discovered: + prev = manifest.files.get(entry.path) + if prev is None: + changes.added.append(entry) + elif prev.file_hash != entry.file_hash: + changes.modified.append(entry) + else: + changes.unchanged.append(entry.path) + + for old_path in manifest.files: + if old_path not in current_paths: + changes.deleted.append(old_path) + + return changes + + # ------------------------------------------------------------------ # + # Stale artifact cleanup # + # ------------------------------------------------------------------ # + + async def _purge_stale_artifacts( + self, + file_paths: List[str], + manifest: CompileManifest, + ) -> None: + """Remove disk artifacts and DuckDB clusters for deleted/modified files. + + Called before recompilation so that modified files start with a + clean slate and deleted files leave no residue. + """ + artifact_dirs = { + "trees": ".json", + "content": ".txt", + "table_digests": ".json", + "xlsx_digests": ".txt", + } + + for file_path in file_paths: + entry = manifest.files.get(file_path) + if entry is None: + continue + + file_hash = entry.file_hash + + # 1. Remove disk artifacts keyed by file_hash + if file_hash: + for subdir, ext in artifact_dirs.items(): + artifact = self._compile_dir / subdir / f"{file_hash}{ext}" + try: + artifact.unlink(missing_ok=True) + except OSError: + pass + + # 2. Remove associated knowledge clusters from DuckDB + for cluster_id in entry.cluster_ids: + try: + await self._storage.remove(cluster_id) + except Exception: + pass + + # 3. Drop the manifest entry + manifest.files.pop(file_path, None) + + # ------------------------------------------------------------------ # + # Single-file compilation # + # ------------------------------------------------------------------ # + + async def _compile_single_file( + self, + entry: FileEntry, + *, + shallow: bool = False, + ) -> FileCompileResult: + """Unified compile pipeline: tree-if-eligible -> summary -> topics -> evidence. + + When *shallow* is True (or file is ineligible for tree indexing), + the pipeline skips tree building and summarises via a direct LLM call. + + Large intermediate objects (extraction output, enriched content, + raw tables) are explicitly released after their last use to keep + per-file peak memory bounded. + """ + result = FileCompileResult(path=entry.path) + try: + await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") + + extraction = await DocumentExtractor.extract_isolated( + entry.path, DocumentExtractor.ENHANCED, + ) + content = extraction.content + content = await self._normalize_bold_headings(content) + if not content or len(content.strip()) < 100: + result.error = "Insufficient text content" + return result + + # Extract scalar metadata from extraction before releasing it + page_count = extraction.page_count + raw_tables = extraction.tables + del extraction + + use_tree = ( + not shallow + and DocumentTreeIndexer.should_build_tree(entry.path, len(content)) + ) + + # Phase 0.5: TOC extraction (layers 1-3 are zero LLM calls) + toc_entries = None + if use_tree: + from sirchmunk.learnings.toc_extractor import TOCExtractor + toc_entries = await TOCExtractor.extract( + entry.path, content, + total_pages=page_count, + ) + if toc_entries: + await self._log.info( + f"[Compile] Extracted TOC with {len(toc_entries)} entries " + f"for {Path(entry.path).name}" + ) + + if use_tree: + result.tree = await self._tree_indexer.build_tree( + entry.path, content, + toc_entries=toc_entries, + total_pages=page_count, + ) + + result.has_explicit_toc = bool(toc_entries) + del toc_entries + result.tree_node_count = self._count_tree_nodes(result.tree) + print(f"SEARCH_WIKI_DEBUG [C2] tree_build: success={result.tree is not None}, nodes={result.tree_node_count}, tree.file_path={result.tree.file_path if result.tree else 'N/A'}", flush=True) + + # --- Summary + topics + evidence (needs content) --- + ext = Path(entry.path).suffix.lower() + evidence_digest = "" + + if ext in (".xlsx", ".xls"): + metadata_prefix, evidence_digest = self._extract_xlsx_sampling(entry.path) + else: + metadata_prefix = self._extract_structured_metadata(entry.path, content) + + # Build enriched_content only for the summary LLM call, then release + if metadata_prefix: + result.summary = await self._extract_summary( + entry.path, metadata_prefix + content, result.tree, + ) + else: + result.summary = await self._extract_summary( + entry.path, content, result.tree, + ) + del metadata_prefix + + result.topics = await self._extract_topics(result.summary) + result.evidence = self._build_evidence(entry, content, result) + + # Persist Excel evidence digest + if evidence_digest.strip(): + try: + digest_dir = self._compile_dir / "xlsx_digests" + digest_dir.mkdir(parents=True, exist_ok=True) + file_hash = get_fast_hash(entry.path) or "" + if file_hash: + (digest_dir / f"{file_hash}.txt").write_text( + evidence_digest, encoding="utf-8", + ) + result.has_xlsx_digest = True + except Exception: + pass + del evidence_digest + + # Cache ENHANCED content to disk + try: + file_hash_content = get_fast_hash(entry.path) or "" + if file_hash_content and content: + content_dir = self._compile_dir / "content" + content_dir.mkdir(parents=True, exist_ok=True) + (content_dir / f"{file_hash_content}.txt").write_text( + content, encoding="utf-8", + ) + except Exception: + pass + + # --- Table digest + integration (needs raw_tables, then release) --- + if raw_tables: + try: + table_digest = self._build_table_digest(raw_tables) + if table_digest: + digest_dir = self._compile_dir / "table_digests" + digest_dir.mkdir(parents=True, exist_ok=True) + file_hash = get_fast_hash(entry.path) or "" + if file_hash: + (digest_dir / f"{file_hash}.json").write_text( + json.dumps(table_digest, ensure_ascii=False), + encoding="utf-8", + ) + result.has_table_digest = True + result.table_count = len(raw_tables) + except Exception: + pass + + if result.tree and result.tree.root: + self._integrate_tables_into_tree( + result.tree.root, raw_tables, + content=content, total_pages=page_count, + ) + + print(f"SEARCH_WIKI_DEBUG [C3] table_digest: generated={result.has_table_digest}, count={result.table_count}", flush=True) + del raw_tables + + # --- Phases 2.5-2.8: secondary table extraction (PDF only) --- + # These phases re-read from the PDF file; `content` is only + # needed for Phase 2.6 fallback and Phase 2.8 enrichment. + if ext == ".pdf": + if result.tree and result.tree.root: + targeted_tables = await self._targeted_table_extraction( + entry.path, result.tree, + ) + await self._supplement_table_digest( + entry.path, targeted_tables, result, + source_label="Targeted extraction", + ) + del targeted_tables + + if page_count: + covered_pages = self._get_covered_table_pages(entry.path) + tree_root = ( + result.tree.root + if result.tree and result.tree.root else None + ) + content_tables = await self._content_based_table_scan( + entry.path, page_count, covered_pages, + enhanced_content=content, tree_root=tree_root, + ) + await self._supplement_table_digest( + entry.path, content_tables, result, + source_label="Content-based scan", + ) + del content_tables + + covered_after_scan = self._get_covered_table_pages(entry.path) + gap_pages = self._find_force_ocr_candidates( + entry.path, page_count, covered_after_scan, + ) + if gap_pages: + ocr_tables = await self._selective_force_ocr_tables( + entry.path, gap_pages, + ) + await self._supplement_table_digest( + entry.path, ocr_tables, result, + source_label="Selective force-OCR", + ) + del ocr_tables + + if result.has_table_digest: + self._enrich_table_digest_content( + entry.path, content, tree_root=None, + ) + + # Content is no longer needed — release before returning + del content + + except Exception as exc: + result.error = str(exc) + await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") + + return result + + @staticmethod + def _is_generic_summary(summary: str, min_specificity_len: int = 80) -> bool: + """Check whether a summary is too generic to be useful for retrieval. + + A generic summary typically contains only structural descriptions + (e.g., "This document contains several sections") without specific + content indicators. Detection uses summary length and information + density as domain-agnostic proxies. + """ + if not summary: + return True + stripped = summary.strip() + if len(stripped) < min_specificity_len: + return True + # Count unique substantive words (>4 chars) as a proxy for specificity + words = set(w.lower() for w in stripped.split() if len(w) > 4) + return len(words) < 8 + + async def _extract_summary( + self, + file_path: str, + content: str, + tree: Optional[DocumentTree] = None, + ) -> str: + """Generate a document-level summary. + + When a tree is available its root already contains an LLM-synthesized + summary (produced by ``_synthesize_root_summary`` during tree build), + so we reuse it directly — unless the summary is too generic (Plan 2), + in which case we fall back to multi-section LLM summarization. + + For large documents without a tree, uses multi-section sampling + (beginning, middle, end) to capture the full scope of the document. + """ + if tree and tree.root and tree.root.summary: + if not self._is_generic_summary(tree.root.summary): + return tree.root.summary + await self._log.info( + f"[Compile] Root summary too generic for {Path(file_path).name}, " + f"falling back to LLM summarization" + ) + + preview = self._build_summary_preview(content) + from sirchmunk.llm.prompts import COMPILE_DOC_SUMMARY + prompt = COMPILE_DOC_SUMMARY.format( + file_name=Path(file_path).name, + document_content=preview, + ) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + return resp.content.strip() + + @staticmethod + def _build_summary_preview(content: str) -> str: + """Build a representative preview for LLM summarisation. + + For short documents (≤ _SUMMARY_PREVIEW_CHARS), returns the full + content. For large documents, samples the beginning, middle, and + end to capture the document's full scope within the token budget. + """ + if len(content) <= _SUMMARY_PREVIEW_CHARS: + return content + + section_size = _SUMMARY_SAMPLE_SECTION_CHARS + mid_start = max(section_size, (len(content) - section_size) // 2) + + head = content[:section_size] + middle = content[mid_start:mid_start + section_size] + tail = content[-section_size:] + + return ( + f"[Beginning of document]\n{head}\n\n" + f"[... content omitted ...]\n\n" + f"[Middle of document]\n{middle}\n\n" + f"[... content omitted ...]\n\n" + f"[End of document]\n{tail}" + ) + + @staticmethod + def _extract_structured_metadata(file_path: str, content: str) -> str: + """Extract structural metadata for non-text document types. + + For spreadsheets and presentations, prepend a structural overview + (sheet names, column headers, slide titles) so the LLM summariser + has better context than raw extracted text alone. + + Returns a metadata prefix string (may be empty for unsupported types). + """ + ext = Path(file_path).suffix.lower() + + if ext == ".xlsx": + metadata, _evidence = KnowledgeCompiler._extract_xlsx_sampling(file_path) + return metadata + if ext == ".pptx": + return KnowledgeCompiler._extract_pptx_metadata(file_path) + + return "" + + @staticmethod + def _compute_xlsx_sample_rows(total_rows: int, num_sheets: int, sheet_rows: int) -> int: + """Compute adaptive sample row count per sheet. + + Strategy: + - Divides _XLSX_TOTAL_ROW_BUDGET equally across sheets + - Small sheets (<=budget) are fully sampled + - Large sheets are capped at budget + - Result clamped to [_XLSX_MIN_ROWS_PER_SHEET, _XLSX_MAX_ROWS_PER_SHEET] + """ + budget_per_sheet = max(1, _XLSX_TOTAL_ROW_BUDGET // max(1, num_sheets)) + n = min(sheet_rows, budget_per_sheet) + return max(_XLSX_MIN_ROWS_PER_SHEET, min(_XLSX_MAX_ROWS_PER_SHEET, n)) + + @staticmethod + def _extract_xlsx_sampling(file_path: str) -> Tuple[str, str]: + """Extract structural metadata AND sampled content from Excel workbook. + + Performs table-level intelligent sampling with adaptive row counts + based on workbook size and sheet complexity. + + Returns: + (metadata_prefix, evidence_digest) + - metadata_prefix: injected into summary generation context + - evidence_digest: structured text usable directly as search evidence + """ + try: + import openpyxl + wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True) + + sheet_names = wb.sheetnames[:_XLSX_MAX_SHEETS] + num_sheets = len(sheet_names) + + # Phase 1: Collect sheet statistics + sheet_stats: List[Dict[str, Any]] = [] + for sheet_name in sheet_names: + ws = wb[sheet_name] + row_count = ws.max_row or 0 + col_count = ws.max_column or 0 + # Read headers (first row) + headers: List[str] = [] + for row in ws.iter_rows(min_row=1, max_row=1, values_only=True): + headers = [str(h) for h in row if h is not None] + break + sheet_stats.append({ + "name": sheet_name, + "rows": row_count, + "cols": col_count, + "headers": headers[:_XLSX_MAX_COLS_DISPLAY], + "ws": ws, + }) + + # Phase 2: Calculate total rows for adaptive sampling + total_rows = sum(s["rows"] for s in sheet_stats) + + meta_lines: List[str] = ["[Excel Workbook Structure]"] + evidence_lines: List[str] = [] + + for stat in sheet_stats: + ws = stat["ws"] + sheet_name = stat["name"] + row_count = stat["rows"] + col_count = stat["cols"] + headers = stat["headers"] + header_str = ", ".join(headers) if headers else "no headers" + + # Metadata line + meta_lines.append( + f"- Sheet '{sheet_name}': {row_count} rows, {col_count} columns, " + f"headers: [{header_str}]" + ) + + # Adaptive sampling + sample_n = KnowledgeCompiler._compute_xlsx_sample_rows( + total_rows, num_sheets, row_count + ) + + evidence_lines.append( + f"[Sheet '{sheet_name}' ({row_count} rows, {col_count} columns)]" + ) + evidence_lines.append(f"Columns: {header_str}") + + # Sample rows + if row_count <= sample_n: + evidence_lines.append(f"(Full content - {row_count} rows)") + else: + evidence_lines.append(f"Sample rows (top {sample_n} of {row_count}):") + + # Build table header + display_headers = headers[:_XLSX_MAX_COLS_DISPLAY] + if display_headers: + evidence_lines.append("| " + " | ".join(display_headers) + " |") + evidence_lines.append("|" + "|".join(["---"] * len(display_headers)) + "|") + + # Read sample rows (skip header row) + numeric_cols: Dict[int, List[float]] = {} # col_index -> numeric values + sampled = 0 + for row in ws.iter_rows( + min_row=2, + max_row=min(row_count, sample_n + 1), + values_only=True, + ): + cells: List[str] = [] + for ci, cell_val in enumerate(row): + if ci >= _XLSX_MAX_COLS_DISPLAY: + break + str_val = str(cell_val) if cell_val is not None else "" + cells.append(str_val[:50]) # truncate long cell values + # Track numeric values for statistics + if isinstance(cell_val, (int, float)) and cell_val == cell_val: + numeric_cols.setdefault(ci, []).append(float(cell_val)) + if cells: + evidence_lines.append("| " + " | ".join(cells) + " |") + sampled += 1 + + # Statistics for numeric columns + stat_parts: List[str] = [] + for ci, values in numeric_cols.items(): + if len(values) >= 2 and ci < len(display_headers): + col_name = display_headers[ci] + stat_parts.append( + f"{col_name} range [{min(values):.4g}-{max(values):.4g}]" + ) + if stat_parts: + evidence_lines.append(f"Statistics: {', '.join(stat_parts[:5])}") + + evidence_lines.append("") # blank line between sheets + + wb.close() + + metadata = "\n".join(meta_lines) + "\n\n" + evidence = "\n".join(evidence_lines) + return metadata, evidence + + except Exception: + return "", "" + + @staticmethod + def _extract_xlsx_metadata(file_path: str) -> str: + """Extract structural metadata from Excel files (legacy wrapper). + + Delegates to _extract_xlsx_sampling and returns only the metadata prefix + for backward compatibility. + """ + metadata, _evidence = KnowledgeCompiler._extract_xlsx_sampling(file_path) + return metadata + + @staticmethod + def _extract_pptx_metadata(file_path: str) -> str: + """Extract structural metadata from PowerPoint files. + + Reads slide count and titles (from the title placeholder) to give + the LLM a table-of-contents-like overview of the presentation. + Caps at 20 slides for bounded output. + """ + try: + from pptx import Presentation + prs = Presentation(file_path) + lines: List[str] = [f"[PowerPoint Structure: {len(prs.slides)} slides]"] + for i, slide in enumerate(prs.slides[:20], 1): # Cap at 20 slides + title = "" + if slide.shapes.title: + title = slide.shapes.title.text.strip() + if title: + lines.append(f"- Slide {i}: {title}") + return "\n".join(lines) + "\n\n" + except Exception: + return "" + + def _build_evidence( + self, + entry: FileEntry, + content: str, + result: FileCompileResult, + ) -> EvidenceUnit: + """Build an EvidenceUnit, populating snippets/tree_path from tree leaves.""" + from sirchmunk.schema.metadata import FileInfo + + snippets: List[str] = [] + tree_path: Optional[List[str]] = None + + if result.tree and result.tree.root: + leaves = result.tree.root.all_leaves() + tree_path = [leaf.node_id for leaf in leaves] + for leaf in leaves: + start, end = leaf.char_range + snippet = content[start:end][:500] + if snippet.strip(): + snippets.append(snippet) + + return EvidenceUnit( + doc_id=FileInfo.get_cache_key(entry.path), + file_or_url=Path(entry.path), + summary=result.summary, + is_found=True, + snippets=snippets, + tree_path=tree_path, + extracted_at=datetime.now(timezone.utc), + ) + + async def _extract_topics(self, summary: str) -> List[str]: + """Extract key topics/entities from a document summary.""" + from sirchmunk.llm.prompts import COMPILE_TOPIC_EXTRACTION + prompt = COMPILE_TOPIC_EXTRACTION.format(summary=summary) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + try: + raw = resp.content.strip() + if raw.startswith("["): + parsed = json.loads(raw) + if isinstance(parsed, list): + return [str(t) for t in parsed if t] + return [t.strip() for t in raw.split(",") if t.strip()] + except (json.JSONDecodeError, TypeError): + return [] + + # ------------------------------------------------------------------ # + # Knowledge aggregation (LLM Wiki Ingest) # + # ------------------------------------------------------------------ # + + async def _aggregate_to_knowledge_network( + self, result: FileCompileResult, + ) -> Tuple[int, int]: + """Aggregate a file's compile result into the knowledge network. + + Three-tier similarity strategy (per design doc): + - similarity >= 0.80 → merge into existing cluster + - 0.50 <= sim < 0.80 → create new cluster + weak edge to similar + - similarity < 0.50 → create standalone cluster + + Returns: + (clusters_created, clusters_merged) + """ + created, merged = 0, 0 + if not result.summary: + return created, merged + + embedding = self._encode_text(result.summary) + + # Search for similar existing clusters across a wider range + best_match: Optional[Dict[str, Any]] = None + if embedding is not None: + similar = await self._storage.search_similar_clusters( + query_embedding=embedding, + top_k=3, + similarity_threshold=0.50, + ) + if similar: + best_match = similar[0] + + if best_match and best_match["similarity"] >= 0.80: + # Tier 1: merge into existing cluster + cluster = await self._storage.get(best_match["id"]) + if cluster: + await self._merge_into_cluster(cluster, result) + # Re-compute embedding for merged content + await self._update_cluster_embedding(cluster) + result.cluster_ids.append(cluster.id) + merged += 1 + return created, merged + + # Create a new cluster (Tier 2 or Tier 3) + cluster = await self._create_cluster(result) + if cluster: + result.cluster_ids.append(cluster.id) + await self._store_cluster_embedding(cluster, embedding, result.summary) + created += 1 + + # Tier 2: build weak edges to moderately similar clusters + if best_match and best_match["similarity"] >= 0.50: + for s in (similar or []): + if s["similarity"] >= 0.50: + target = await self._storage.get(s["id"]) + if target: + self._add_edge(cluster, target.id, "embed_sim", s["similarity"]) + self._add_edge(target, cluster.id, "embed_sim", s["similarity"]) + await self._storage.update(target) + await self._storage.update(cluster) + + return created, merged + + def _encode_text(self, text: str) -> Optional[Any]: + """Encode text to embedding vector, returns None on failure.""" + if not self._embedding or not self._embedding.is_ready(): + return None + try: + vectors = self._embedding._encode_sync([text]) + return vectors[0] if len(vectors) > 0 else None + except Exception: + return None + + async def _store_cluster_embedding( + self, cluster: KnowledgeCluster, embedding: Optional[Any], text: str, + ) -> None: + """Store embedding for a cluster if available.""" + if embedding is None or not self._embedding: + return + text_hash = hashlib.md5(text.encode()).hexdigest() + vec = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding) + await self._storage.store_embedding( + cluster.id, vec, + self._embedding.model_id or "default", + text_hash, + ) + + async def _update_cluster_embedding(self, cluster: KnowledgeCluster) -> None: + """Re-compute and store embedding after content merge.""" + content_text = str(cluster.content)[:2000] if cluster.content else "" + if not content_text: + return + embedding = self._encode_text(content_text) + await self._store_cluster_embedding(cluster, embedding, content_text) + + async def _merge_into_cluster( + self, + cluster: KnowledgeCluster, + result: FileCompileResult, + ) -> None: + """Merge a file compile result into an existing cluster.""" + # Append evidence + if result.evidence: + existing_doc_ids = {e.doc_id for e in cluster.evidences} + if result.evidence.doc_id not in existing_doc_ids: + cluster.evidences.append(result.evidence) + + # Enrich content via LLM merge + from sirchmunk.llm.prompts import COMPILE_MERGE_KNOWLEDGE + prompt = COMPILE_MERGE_KNOWLEDGE.format( + existing_content=str(cluster.content)[:3000], + new_summary=result.summary[:3000], + ) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + cluster.content = resp.content.strip() + + # Update metadata + cluster.search_results = list(set( + (cluster.search_results or []) + [result.path] + )) + merge_count = getattr(cluster, "merge_count", 0) or 0 + cluster.merge_count = merge_count + 1 + + # Lifecycle promotion + if cluster.merge_count >= 3 and cluster.lifecycle == Lifecycle.EMERGING: + cluster.lifecycle = Lifecycle.STABLE + + await self._storage.update(cluster) + + async def _create_cluster( + self, result: FileCompileResult, + ) -> Optional[KnowledgeCluster]: + """Create a new KnowledgeCluster from a file compile result.""" + cluster_text = result.summary + cluster_id = f"C{hashlib.sha256(cluster_text.encode('utf-8')).hexdigest()[:10]}" + + name = Path(result.path).stem[:60] + if result.topics: + name = result.topics[0][:60] + + cluster = KnowledgeCluster( + id=cluster_id, + name=name, + description=[result.summary[:500]], + content=result.summary, + evidences=[result.evidence] if result.evidence else [], + patterns=result.topics[:5], + lifecycle=Lifecycle.EMERGING, + confidence=0.5, + abstraction_level=AbstractionLevel.TECHNIQUE, + hotness=0.3, + search_results=[result.path], + ) + + ok = await self._storage.insert(cluster) + return cluster if ok else None + + # ------------------------------------------------------------------ # + # Cross-references # + # ------------------------------------------------------------------ # + + async def _build_cross_references_from_pairs( + self, + pairs: List[Tuple[str, List[str]]], + manifest: CompileManifest, + ) -> int: + """Build co-occurrence edges between clusters that share source files. + + Accepts lightweight ``(path, cluster_ids)`` pairs instead of full + ``FileCompileResult`` objects to avoid retaining heavy compile results. + Includes historical data from the manifest. + """ + cluster_to_files: Dict[str, Set[str]] = {} + + for path, cluster_ids in pairs: + for cid in cluster_ids: + cluster_to_files.setdefault(cid, set()).add(path) + + for fp, entry in manifest.files.items(): + for cid in entry.cluster_ids: + cluster_to_files.setdefault(cid, set()).add(fp) + + # Find cluster pairs that share at least one source file + cluster_ids = list(cluster_to_files.keys()) + edges_created = 0 + pairs_seen: Set[Tuple[str, str]] = set() + + for i in range(len(cluster_ids)): + for j in range(i + 1, len(cluster_ids)): + cid_a, cid_b = cluster_ids[i], cluster_ids[j] + shared = cluster_to_files[cid_a] & cluster_to_files[cid_b] + if not shared: + continue + + pair_key = (min(cid_a, cid_b), max(cid_a, cid_b)) + if pair_key in pairs_seen: + continue + pairs_seen.add(pair_key) + + weight = min(len(shared) * 0.25, 1.0) + c_a = await self._storage.get(cid_a) + c_b = await self._storage.get(cid_b) + if c_a and c_b: + self._add_edge(c_a, cid_b, "co_occur", weight) + self._add_edge(c_b, cid_a, "co_occur", weight) + await self._storage.update(c_a) + await self._storage.update(c_b) + edges_created += 1 + + return edges_created + + @staticmethod + def _add_edge( + cluster: KnowledgeCluster, target_id: str, source: str, weight: float, + ) -> None: + """Add or update a WeakSemanticEdge on a cluster.""" + for edge in cluster.related_clusters: + if edge.target_cluster_id == target_id and edge.source == source: + edge.weight = max(edge.weight, weight) + return + cluster.related_clusters.append( + WeakSemanticEdge(target_cluster_id=target_id, weight=weight, source=source) + ) + + def _build_table_digest( + self, tables: List[Dict[str, Any]], + ) -> Optional[Dict[str, Any]]: + """Build a structured table digest from extraction output. + + Returns a versioned JSON-serializable dict containing all tables + with their page numbers, markdown representation, and cell data. + Tables are indexed for page-range-based retrieval at search time. + """ + if not tables: + return None + + digest_tables = [] + for idx, table in enumerate(tables): + markdown = table.get("markdown", "") + cells = table.get("cells", []) + if not markdown and not cells: + continue + + # Compute row/col counts from cells (kreuzberg returns List[List[str]]) + row_count = 0 + col_count = 0 + if cells: + row_count = len(cells) + col_count = max((len(row) for row in cells if isinstance(row, (list, tuple))), default=0) + elif markdown: + # Estimate from markdown lines + lines = [l for l in markdown.strip().split("\n") if l.strip().startswith("|")] + row_count = max(0, len(lines) - 1) # exclude separator + col_count = lines[0].count("|") - 1 if lines else 0 + + # Skip pseudo-tables: single-column or insufficient structure + if col_count <= 1: + continue + + digest_tables.append({ + "index": idx, + "page_number": table.get("page_number"), + "markdown": markdown, + "row_count": row_count, + "col_count": col_count, + "cells": cells, + }) + + if not digest_tables: + return None + + return { + "version": 1, + "table_count": len(digest_tables), + "tables": digest_tables, + } + + def _integrate_tables_into_tree( + self, + node: "TreeNode", + tables: List[Dict[str, Any]], + content: str, + *, + total_pages: Optional[int] = None, + _counter: Optional[List[int]] = None, + ) -> None: + """Integrate tables into tree: annotate counts AND create table child nodes for leaf nodes. + + For each node with a valid page_range, counts how many valid extracted + tables fall within that range (excluding pseudo-tables with col_count <= 1). + For leaf nodes with matching tables, creates dedicated TreeNode children + with ``content_type="table"``. + """ + from sirchmunk.learnings.tree_indexer import TreeNode + + if node is None: + return + + if _counter is None: + _counter = [0] + + # Depth-first: process existing children first + for child in list(node.children): + self._integrate_tables_into_tree( + child, tables, content, + total_pages=total_pages, _counter=_counter, + ) + + # Match valid tables to this node's page_range + matched_tables: List[Dict[str, Any]] = [] + if node.page_range: + ps, pe = node.page_range + for t in tables: + pn = t.get("page_number") + if pn is None or not (ps <= pn <= pe): + continue + # Skip pseudo-tables + if self._is_pseudo_table(t): + continue + matched_tables.append(t) + + node.table_count = len(matched_tables) + + # NOTE: _spawn_table_children disabled - converting leaf to non-leaf breaks + # search navigation which expects leaves for char_range extraction. + # TODO: Re-enable when search can properly handle mixed text+table children. + # if not node.children and matched_tables: + # try: + # self._spawn_table_children( + # node, matched_tables, content, _counter, + # ) + # except Exception: + # pass + + @staticmethod + def _is_pseudo_table(table: Dict[str, Any]) -> bool: + """Return True if the table lacks meaningful structure (col_count <= 1).""" + markdown = table.get("markdown", "") + cells = table.get("cells", []) + if not markdown and not cells: + return True + col_count = 0 + if cells: + col_count = max( + (len(row) for row in cells if isinstance(row, (list, tuple))), + default=0, + ) + elif markdown: + lines = [l for l in markdown.strip().split("\n") if l.strip().startswith("|")] + col_count = (lines[0].count("|") - 1) if lines else 0 + return col_count <= 1 + + def _spawn_table_children( + self, + node: "TreeNode", + matched_tables: List[Dict[str, Any]], + content: str, + counter: List[int], + ) -> None: + """Create TreeNode children for each matched table under a leaf node. + + Also inserts a text-content sibling preserving the original leaf content. + """ + from sirchmunk.learnings.tree_indexer import TreeNode + + child_level = node.level + 1 + + # Preserve original text content as first child + text_child_id = f"T{counter[0]:06d}" + counter[0] += 1 + node.children.append( + TreeNode( + node_id=text_child_id, + title=node.title, + summary=node.summary[:300] if node.summary else "", + char_range=node.char_range, + level=child_level, + page_range=node.page_range, + children=[], + table_count=0, + content_type="text", + ) + ) + + # Create one child per table + for table in matched_tables: + tid = f"T{counter[0]:06d}" + counter[0] += 1 + + markdown = table.get("markdown", "") + title = self._extract_table_title(table) + page_number = table.get("page_number") + + # Attempt to locate table markdown in content + char_range = node.char_range + if markdown and content: + pos = content.find(markdown[:120]) + if pos >= 0: + char_range = (pos, pos + len(markdown)) + + page_range = ( + (page_number, page_number) if page_number is not None + else node.page_range + ) + + node.children.append( + TreeNode( + node_id=tid, + title=title, + summary=markdown[:300] if markdown else "", + char_range=char_range, + level=child_level, + page_range=page_range, + children=[], + table_count=0, + content_type="table", + ) + ) + + @staticmethod + def _extract_table_title(table: Dict[str, Any]) -> str: + """Extract a concise title from table markdown header row. + + Parses the first meaningful line of the markdown table (skipping + separator rows like ``|---|---|``), strips ``|`` delimiters, and + returns the first 80 characters as the title. + """ + markdown = table.get("markdown", "") + if not markdown: + pn = table.get("page_number", "?") + return f"Table (p.{pn})" + + for line in markdown.strip().split("\n"): + stripped = line.strip() + if not stripped: + continue + # Skip separator rows (e.g. |---|---| or +---+---+) + content_chars = stripped.replace("|", "").replace("-", "").replace(":", "").replace("+", "").strip() + if not content_chars: + continue + # Extract cell contents + title = " | ".join( + seg.strip() for seg in stripped.split("|") if seg.strip() + ) + return title[:80] if title else f"Table (p.{table.get('page_number', '?')})" + + pn = table.get("page_number", "?") + return f"Table (p.{pn})" + + @staticmethod + def _count_tree_nodes(tree: Optional[DocumentTree]) -> int: + """Count total nodes in a DocumentTree (recursive). + + Args: + tree: The tree to count, or None. + + Returns: + Total node count, or 0 if tree is None. + """ + if tree is None or tree.root is None: + return 0 + + def _count(node: Any) -> int: + return 1 + sum(_count(c) for c in node.children) + + return _count(tree.root) + + # ------------------------------------------------------------------ # + # Targeted table extraction # + # ------------------------------------------------------------------ # + + async def _targeted_table_extraction( + self, file_path: str, tree: DocumentTree, + ) -> list[dict]: + """Extract tables from tree nodes likely containing tabular data. + + Uses generic structural signals (metadata, page span, numeric + density) instead of domain-specific title keywords. For each + candidate with a valid ``page_range``, extracts per-page text + via :meth:`DocumentExtractor.extract_page_range` and applies + heuristic table-region detection. Pages whose numeric density + falls below ``_TABLE_NUMERIC_DENSITY_THRESHOLD`` are skipped. + + Returns: + List of table dicts compatible with the table-digest format:: + + {"page": int, "content": str, "source": str} + """ + if tree is None or tree.root is None: + return [] + + candidates = self._find_table_candidate_nodes(tree.root) + if not candidates: + return [] + + await self._log.info( + f"[Compile] Targeted extraction: {len(candidates)} candidate " + f"nodes in {Path(file_path).name}" + ) + + results: list[dict] = [] + seen_pages: set[int] = set() + + for node in candidates: + if node.page_range is None: + continue + start_page, end_page = node.page_range + # Skip pages already processed by another candidate + page_nums = [p for p in range(start_page, end_page + 1) + if p not in seen_pages] + if not page_nums: + continue + + try: + pages = DocumentExtractor.extract_page_range( + file_path, start_page, end_page, + ) + except Exception as exc: + await self._log.warning( + f"[Compile] Targeted extraction page read failed " + f"({start_page}-{end_page}): {exc}" + ) + continue + + for pc in pages: + if pc.page_number in seen_pages: + continue + seen_pages.add(pc.page_number) + # Numeric density gate – skip pages unlikely to contain tables + if not self._page_has_table_density(pc.content): + continue + regions = self._identify_table_regions(pc.content) + for region in regions: + truncated = region[:_TARGETED_TABLE_MAX_CHARS] + results.append({ + "page": pc.page_number, + "content": truncated, + "source": f"targeted:{node.title[:80]}", + }) + + return results + + def _find_table_candidate_nodes( + self, root: "TreeNode", + ) -> list["TreeNode"]: + """Collect leaf nodes that likely contain tables. + + Uses generic, domain-agnostic structural signals (any match + suffices): + + - ``node.content_type == "table"`` – already tagged during compile. + - ``node.table_count > 0`` – known to contain tables. + - Has a valid ``page_range`` with span ≤ ``_TABLE_PAGE_SPAN_LIMIT``. + """ + candidates: list = [] + + def _walk(node: "TreeNode") -> None: + if node.leaf: + # Signal 1: content_type marked as table + if getattr(node, "content_type", None) == "table": + candidates.append(node) + return + # Signal 2: known to contain tables + if getattr(node, "table_count", 0) > 0: + candidates.append(node) + return + # Signal 3: moderate page span (tables rarely span many pages) + page_range = getattr(node, "page_range", None) + if page_range and len(page_range) == 2: + span = page_range[1] - page_range[0] + 1 + if 1 <= span <= _TABLE_PAGE_SPAN_LIMIT: + candidates.append(node) + else: + for child in node.children: + _walk(child) + + _walk(root) + return candidates + + # ------------------------------------------------------------------ # + # LLM-based heading normalisation # + # ------------------------------------------------------------------ # + + @staticmethod + def _extract_heading_candidates( + content: str, + ) -> list[tuple[re.Match, str, str]]: + """Extract candidate lines that *might* be section headings. + + Returns a list of ``(match, title_text, source_tag)`` triples + where *source_tag* is ``"bold"`` or ``"standalone"``. + + Bold lines (``**Title**``) are always candidates. Short + standalone lines (surrounded by blank lines, 10-100 chars) are + included only when they pass structural heuristics that filter + out data rows, sentences, and existing headings. + """ + occupied: list[tuple[int, int]] = [] + candidates: list[tuple[re.Match, str, str]] = [] + + def _overlaps(start: int, end: int) -> bool: + return any(s < end and start < e for s, e in occupied) + + for m in _BOLD_LINE_RE.finditer(content): + title = m.group(1).strip() + if title and not _overlaps(m.start(), m.end()): + occupied.append((m.start(), m.end())) + candidates.append((m, title, "bold")) + + for m in _STANDALONE_LINE_RE.finditer(content): + text = m.group(1).strip() + if len(text) < 10: + continue + text_offset = m.start() + m.group(0).index(m.group(1)) + if _overlaps(text_offset, text_offset + len(m.group(1))): + continue + if text.startswith(("#", "**")): + continue + if _NUM_TOKEN_RE.search(text): + continue + if text.endswith((".", "。", "!", "?", "!", "?")): + continue + if len(text.split()) > 12: + continue + occupied.append((text_offset, text_offset + len(m.group(1)))) + candidates.append((m, text, "standalone")) + + candidates.sort(key=lambda t: t[0].start()) + return candidates[:_HEADING_CANDIDATE_CAP] + + async def _normalize_bold_headings(self, content: str) -> str: + """Detect and promote bold/standalone section titles to headings. + + Three-phase pipeline: + 1. **Extract** candidate lines via regex (deterministic). + 2. **Classify** candidates with a single LLM call — the LLM + returns which indices are section headings and their level. + 3. **Replace** confirmed headings deterministically. + + Short-circuits when no candidates are found (zero LLM calls). + On any LLM / parse failure, returns the original content unchanged + (graceful degradation — equivalent to no-op). + + The transformation is idempotent: existing ``#`` headings never + enter the candidate set. + """ + if not content: + return content + + candidates = self._extract_heading_candidates(content) + if not candidates: + return content + + listing = "\n".join( + f"{i}: \"{title}\"" for i, (_, title, _tag) in enumerate(candidates) + ) + + from sirchmunk.llm.prompts import COMPILE_CLASSIFY_HEADINGS + prompt = COMPILE_CLASSIFY_HEADINGS.format(candidates=listing) + + try: + resp = await self._llm.achat( + [{"role": "user", "content": prompt}], + ) + raw = resp.content.strip() + headings = self._parse_heading_classifications(raw, len(candidates)) + except Exception: + return content + + if not headings: + return content + + return self._apply_heading_promotions(content, candidates, headings) + + @staticmethod + def _parse_heading_classifications( + raw: str, + num_candidates: int, + ) -> list[tuple[int, int]]: + """Parse LLM JSON response into a list of ``(idx, level)`` pairs. + + Robustly handles markdown code fences, trailing commas, and + out-of-range indices. Returns an empty list on any parse failure. + """ + cleaned = raw.strip() + if cleaned.startswith("```"): + lines = cleaned.splitlines() + lines = [ln for ln in lines if not ln.strip().startswith("```")] + cleaned = "\n".join(lines).strip() + + try: + items = json.loads(cleaned) + except json.JSONDecodeError: + m = re.search(r"\[.*\]", cleaned, re.DOTALL) + if not m: + return [] + try: + items = json.loads(m.group()) + except json.JSONDecodeError: + return [] + + if not isinstance(items, list): + return [] + + result: list[tuple[int, int]] = [] + for item in items: + if isinstance(item, dict): + idx = item.get("idx") + level = item.get("level", 2) + elif isinstance(item, int): + idx, level = item, 2 + else: + continue + if not isinstance(idx, int) or not (0 <= idx < num_candidates): + continue + level = max(2, min(4, int(level))) + result.append((idx, level)) + return result + + @staticmethod + def _apply_heading_promotions( + content: str, + candidates: list[tuple[re.Match, str, str]], + headings: list[tuple[int, int]], + ) -> str: + """Apply heading promotions to *content* in reverse-offset order. + + Processes replacements from end-to-start so that earlier offsets + remain valid after each substitution. + """ + heading_map: dict[int, int] = dict(headings) + + replacements: list[tuple[int, int, str]] = [] + for idx, (match, title, tag) in enumerate(candidates): + if idx not in heading_map: + continue + level = heading_map[idx] + prefix = "#" * level + if tag == "bold": + replacements.append((match.start(), match.end(), f"{prefix} {title}")) + else: + text_start = match.start() + match.group(0).index(match.group(1)) + text_end = text_start + len(match.group(1)) + replacements.append((text_start, text_end, f"{prefix} {title}")) + + replacements.sort(key=lambda r: r[0], reverse=True) + for start, end, replacement in replacements: + content = content[:start] + replacement + content[end:] + return content + + @staticmethod + def _page_has_table_density(page_text: str) -> bool: + """Return True if *page_text* likely contains tabular numeric data. + + Two independent signals (either suffices): + + 1. **Character-level density** — fraction of digit/symbol chars + relative to total non-whitespace exceeds the threshold. + 2. **Token-dense line** — any single line contains + ``_DENSE_LINE_MIN_TOKENS`` or more numeric tokens, which + catches pages where pypdf flattens all content into ≤ 2 lines. + """ + if not page_text: + return False + non_ws = sum(1 for ch in page_text if not ch.isspace()) + if non_ws == 0: + return False + numeric_chars = sum( + 1 for ch in page_text + if ch.isdigit() or ch in "$%(),.+-" + ) + if (numeric_chars / non_ws) >= _TABLE_NUMERIC_DENSITY_THRESHOLD: + return True + return any( + len(_NUM_TOKEN_RE.findall(line)) >= _DENSE_LINE_MIN_TOKENS + for line in page_text.split("\n") + ) + + @staticmethod + def _identify_table_regions(page_text: str) -> list[str]: + """Identify contiguous table-like regions in *page_text*. + + Two complementary strategies: + + 1. **Consecutive-line detection** — a run of ≥ 3 lines each + containing ≥ 2 numeric tokens forms a table region. Works + well when pypdf preserves per-row line breaks. + 2. **Dense-line detection** — a *single* line with ≥ + ``_DENSE_LINE_MIN_TOKENS`` numeric tokens is treated as a + table region. This handles PDFs where pypdf collapses + the entire page into one or two very long lines. + + Returns: + List of extracted region strings (may be empty). + """ + if not page_text: + return [] + + _MIN_NUMS_PER_LINE = 2 + _MIN_CONSECUTIVE = 3 + + lines = page_text.split("\n") + token_counts = [ + len(_NUM_TOKEN_RE.findall(line)) for line in lines + ] + + regions: list[str] = [] + captured_lines: set[int] = set() + + # --- Strategy 1: consecutive-line runs --- + run_start: int | None = None + for i, cnt in enumerate(token_counts): + if cnt >= _MIN_NUMS_PER_LINE: + if run_start is None: + run_start = i + else: + if run_start is not None: + if i - run_start >= _MIN_CONSECUTIVE: + start = max(0, run_start - 1) + end = min(len(lines), i + 1) + regions.append( + "\n".join(lines[start:end]).strip() + ) + captured_lines.update(range(start, end)) + run_start = None + if run_start is not None and len(lines) - run_start >= _MIN_CONSECUTIVE: + start = max(0, run_start - 1) + regions.append("\n".join(lines[start:]).strip()) + captured_lines.update(range(start, len(lines))) + + # --- Strategy 2: dense-line detection --- + for i, cnt in enumerate(token_counts): + if cnt >= _DENSE_LINE_MIN_TOKENS and i not in captured_lines: + start = max(0, i - 1) + end = min(len(lines), i + 2) + regions.append("\n".join(lines[start:end]).strip()) + + return regions + + @staticmethod + def _get_table_page(entry: dict) -> int | None: + """统一获取表格条目的页码,兼容 page_number 和 page 两种字段名。""" + p = entry.get("page_number") or entry.get("page") + return int(p) if p is not None else None + + @classmethod + def _merge_table_digests( + cls, existing: list[dict], new_tables: list[dict], + ) -> list[dict]: + """Merge *new_tables* into *existing* digest, deduplicating by page. + + If an existing entry and a new entry share the same page number, + the new entry is skipped (existing kreuzberg-detected table takes + precedence because it has richer structure like cells/markdown). + + Returns: + Merged list suitable for storage in the table-digest JSON. + """ + existing_pages = {cls._get_table_page(e) for e in existing} + existing_pages.discard(None) + + merged = list(existing) + for tbl in new_tables: + page = cls._get_table_page(tbl) + if page is not None and page in existing_pages: + continue + merged.append({ + "page_number": page, + "markdown": tbl.get("markdown", "") or tbl.get("content", ""), + "row_count": tbl.get("row_count"), + "col_count": tbl.get("col_count"), + "cells": tbl.get("cells", []), + "source": tbl.get("source", "supplementary"), + }) + return merged + + async def _supplement_table_digest( + self, + file_path: str, + new_tables: list[dict], + result: "FileCompileResult", + *, + source_label: str, + ) -> None: + """Merge supplementary tables into the persisted table digest. + + Loads the existing digest (if any), merges *new_tables* with + page-level deduplication, and writes the updated digest back. + Updates *result* metadata in place. + """ + if not new_tables: + return + + file_hash = get_fast_hash(file_path) or "" + if not file_hash: + return + + digest_dir = self._compile_dir / "table_digests" + digest_path = digest_dir / f"{file_hash}.json" + + existing: list[dict] = [] + if result.has_table_digest and digest_path.exists(): + try: + raw = json.loads(digest_path.read_text(encoding="utf-8")) + existing = raw.get("tables", []) + except Exception: + pass + + merged = self._merge_table_digests(existing, new_tables) + if not merged: + return + + digest_dir.mkdir(parents=True, exist_ok=True) + digest_path.write_text( + json.dumps( + {"version": 1, "table_count": len(merged), "tables": merged}, + ensure_ascii=False, + ), + encoding="utf-8", + ) + result.has_table_digest = True + result.table_count = len(merged) + await self._log.info( + f"[Compile] {source_label}: +{len(new_tables)} tables for " + f"{Path(file_path).name} (total={len(merged)})" + ) + + def _get_covered_table_pages(self, file_path: str) -> Set[int]: + """Return the set of page numbers already present in the table digest.""" + file_hash = get_fast_hash(file_path) or "" + if not file_hash: + return set() + + digest_path = ( + self._compile_dir / "table_digests" / f"{file_hash}.json" + ) + if not digest_path.exists(): + return set() + + try: + raw = json.loads(digest_path.read_text(encoding="utf-8")) + pages: Set[int] = set() + for t in raw.get("tables", []): + p = self._get_table_page(t) + if p is not None: + pages.add(p) + return pages + except Exception: + return set() + + # ------------------------------------------------------------------ # + # P1: Enrich table digest with ENHANCED content # + # ------------------------------------------------------------------ # + + @staticmethod + def _build_page_char_map( + tree_root: Any, + max_page_span: int = _TABLE_PAGE_SPAN_LIMIT, + ) -> Dict[int, Tuple[int, int]]: + """Map page numbers to ``(start_char, end_char)`` in ENHANCED content. + + Aggregates ``char_range`` bounds from leaf nodes whose + ``page_range`` intersects a given page. To avoid inflated + ranges from wide-spanning nodes (e.g. a cover-page node + spanning pages 1–85), only nodes with a page span ≤ + *max_page_span* are used when available; wider nodes serve + as a fallback. + """ + # (char_start, char_end, page_span) per page + entries: Dict[int, List[Tuple[int, int, int]]] = {} + + def _walk(node: Any) -> None: + children = getattr(node, "children", None) or [] + if isinstance(node, dict): + children = node.get("children", []) + if not children: + pr = ( + getattr(node, "page_range", None) + if not isinstance(node, dict) + else node.get("page_range") + ) + cr = ( + getattr(node, "char_range", None) + if not isinstance(node, dict) + else node.get("char_range") + ) + if ( + pr + and cr + and len(pr) >= 2 + and len(cr) >= 2 + ): + span = int(pr[1]) - int(pr[0]) + 1 + for p in range(int(pr[0]), int(pr[1]) + 1): + entries.setdefault(p, []).append( + (int(cr[0]), int(cr[1]), span) + ) + for ch in children: + _walk(ch) + + _walk(tree_root) + + result: Dict[int, Tuple[int, int]] = {} + for page, elist in entries.items(): + narrow = [e for e in elist if e[2] <= max_page_span] + chosen = narrow if narrow else elist + result[page] = ( + min(e[0] for e in chosen), + max(e[1] for e in chosen), + ) + return result + + @staticmethod + def _find_enhanced_region( + enhanced_content: str, + pypdf_text: str, + budget: int = _TARGETED_TABLE_MAX_CHARS, + ) -> Optional[str]: + """Locate the ENHANCED content region matching *pypdf_text*. + + Uses progressively shorter text anchors extracted from the + pypdf content to find the corresponding position in the + ENHANCED (kreuzberg markdown) text. Whitespace is normalised + in the anchor to handle formatting differences (pypdf line + breaks vs kreuzberg markdown spacing). This avoids reliance + on page-number alignment, which may differ between the two + extractors. + + Returns the ENHANCED slice (up to *budget* chars) or ``None``. + """ + text = pypdf_text.strip() + for prefix in ("Table of Contents\n", "Table of Contents "): + if text.startswith(prefix): + text = text[len(prefix):] + text = text.strip() + + for anchor_len in (80, 50, 30): + raw = text[:anchor_len].strip() + if len(raw) < 15: + continue + anchor = " ".join(raw.split()) + pos = enhanced_content.find(anchor) + if pos < 0: + continue + start = max( + 0, + enhanced_content.rfind("\n", max(0, pos - 300), pos) + 1, + ) + end = min(len(enhanced_content), start + budget) + return enhanced_content[start:end].strip() + + return None + + def _enrich_table_digest_content( + self, + file_path: str, + enhanced_content: str, + tree_root: Optional[Any], + ) -> None: + """Replace pypdf-sourced table text with ENHANCED content slices. + + Targeted extraction tables use pypdf, which often produces dense + single-line text (the "2-line page" problem). This method + locates each table's content in the ENHANCED (kreuzberg markdown) + text via anchor matching and replaces the ``markdown`` field when + the ENHANCED version has substantially better structure. + + Only tables whose ``source`` indicates pypdf origin are + candidates; kreuzberg-detected tables already have high-quality + markdown and are left untouched. + """ + if not enhanced_content: + return + + file_hash = get_fast_hash(file_path) or "" + if not file_hash: + return + + digest_path = ( + self._compile_dir / "table_digests" / f"{file_hash}.json" + ) + if not digest_path.exists(): + return + + try: + raw = json.loads(digest_path.read_text(encoding="utf-8")) + tables = raw.get("tables", []) + except Exception: + return + + if not tables: + return + + modified = False + for table in tables: + source = table.get("source", "") + if not ( + source.startswith("targeted:") + or source == "content_scan" + ): + continue + + current = table.get("markdown", "") + if not current: + continue + + enhanced_region = self._find_enhanced_region( + enhanced_content, current, + ) + if not enhanced_region: + continue + + current_lines = len(current.strip().split("\n")) + enhanced_lines = len(enhanced_region.split("\n")) + + if enhanced_lines > max(current_lines, 3): + table["markdown"] = enhanced_region[ + :_TARGETED_TABLE_MAX_CHARS + ] + modified = True + + if modified: + digest_path.write_text( + json.dumps(raw, ensure_ascii=False), + encoding="utf-8", + ) + + # ------------------------------------------------------------------ # + # Tree-independent content-based table scanning # + # ------------------------------------------------------------------ # + + async def _content_based_table_scan( + self, + file_path: str, + total_pages: Optional[int], + covered_pages: Set[int], + *, + enhanced_content: Optional[str] = None, + tree_root: Optional[Any] = None, + ) -> list[dict]: + """Scan PDF pages for table-like regions via numeric density. + + Uses a two-tier strategy: + + 1. **pypdf page scan** — reads every page individually. Works well + when pypdf preserves per-row line breaks. + 2. **ENHANCED content fallback** — if pypdf yields poor line + structure (> 50 % of pages have ≤ 3 lines), falls back to + scanning the kreuzberg ENHANCED markdown content, which often + has better formatting. Page numbers are recovered via the + tree's ``char_range → page_range`` mapping. + + Args: + file_path: Path to the PDF file. + total_pages: Total page count. + covered_pages: Page numbers already in the table digest. + enhanced_content: Cached kreuzberg ENHANCED text (optional). + tree_root: Tree root node for char → page mapping (optional). + + Returns: + List of table dicts compatible with the digest format. + """ + if not total_pages or total_pages <= 0: + return [] + + results = await self._pypdf_page_scan( + file_path, total_pages, covered_pages, + ) + + if results or not enhanced_content or not tree_root: + return results + + return self._enhanced_content_scan( + enhanced_content, total_pages, covered_pages, tree_root, + ) + + async def _pypdf_page_scan( + self, + file_path: str, + total_pages: int, + covered_pages: Set[int], + ) -> list[dict]: + """Primary scan: per-page pypdf extraction with density heuristics. + + Pages are loaded in batches of ``_PAGE_SCAN_BATCH_SIZE`` to bound + peak memory when processing large PDFs (200-400+ pages). + """ + results: list[dict] = [] + poor_line_count = 0 + + for batch_start in range(1, total_pages + 1, _PAGE_SCAN_BATCH_SIZE): + batch_end = min(batch_start + _PAGE_SCAN_BATCH_SIZE, total_pages + 1) + batch_pages = list(range(batch_start, batch_end)) + try: + pages = DocumentExtractor.extract_pages(file_path, batch_pages) + except Exception as exc: + await self._log.warning( + f"[Compile] Content-based scan: page read failed for " + f"{Path(file_path).name}: {exc}" + ) + return [] + + for pc in pages: + if len(pc.content.split("\n")) <= 3: + poor_line_count += 1 + if pc.page_number in covered_pages: + continue + if not self._page_has_table_density(pc.content): + continue + for region in self._identify_table_regions(pc.content): + results.append({ + "page": pc.page_number, + "content": region[:_TARGETED_TABLE_MAX_CHARS], + "source": "content_scan", + }) + del pages + + if results: + return results + + if poor_line_count > total_pages * 0.5: + return [] + + return results + + @staticmethod + def _enhanced_content_scan( + enhanced_content: str, + total_pages: int, + covered_pages: Set[int], + tree_root: Any, + ) -> list[dict]: + """Fallback scan: use ENHANCED (kreuzberg markdown) content. + + Scans the full ENHANCED text line-by-line for dense-token lines, + then maps each detected region back to a page number using the + tree's ``char_range → page_range`` mapping. + """ + char_page_map = KnowledgeCompiler._build_char_to_page_map( + tree_root, total_pages, + ) + if not char_page_map: + return [] + + breakpoints = [cp[0] for cp in char_page_map] + + results: list[dict] = [] + offset = 0 + for line in enhanced_content.split("\n"): + token_count = len(_NUM_TOKEN_RE.findall(line)) + if token_count >= _DENSE_LINE_MIN_TOKENS: + idx = bisect.bisect_right(breakpoints, offset) - 1 + page = char_page_map[max(0, idx)][1] if idx >= 0 else 1 + if page not in covered_pages: + results.append({ + "page": page, + "content": line[:_TARGETED_TABLE_MAX_CHARS], + "source": "content_scan:enhanced", + }) + covered_pages.add(page) + offset += len(line) + 1 # +1 for '\n' + + return results + + @staticmethod + def _build_char_to_page_map( + tree_root: Any, + total_pages: int, + ) -> list[tuple[int, int]]: + """Build a sorted (char_start, page_number) list from tree leaves. + + Enables efficient binary-search lookup from any character offset + in the ENHANCED content to the corresponding page number. + """ + entries: list[tuple[int, int]] = [] + + def _collect(node: Any) -> None: + children = getattr(node, "children", None) or [] + if isinstance(node, dict): + children = node.get("children", []) + pr = ( + getattr(node, "page_range", None) + if not isinstance(node, dict) + else node.get("page_range") + ) + cr = ( + getattr(node, "char_range", None) + if not isinstance(node, dict) + else node.get("char_range") + ) + if not children and cr and pr: + page = pr[0] if isinstance(pr, (list, tuple)) else pr + char_start = cr[0] if isinstance(cr, (list, tuple)) else cr + if page and char_start is not None: + entries.append((int(char_start), int(page))) + for ch in children: + _collect(ch) + + _collect(tree_root) + + if not entries: + return [(0, 1)] + entries.sort() + return entries + + def _find_force_ocr_candidates( + self, + file_path: str, + total_pages: Optional[int], + covered_pages: Set[int], + ) -> List[int]: + """Identify pages worth re-extracting with forced OCR. + + Returns 0-indexed page numbers for pages that have high numeric + density (suggesting tabular content) but are NOT already covered + by any table in the digest. The result is capped at + :data:`_FORCE_OCR_MAX_PAGES`. + """ + if not total_pages or total_pages <= 0: + return [] + + all_page_nums = list(range(1, total_pages + 1)) + try: + pages = DocumentExtractor.extract_pages(file_path, all_page_nums) + except Exception: + return [] + + candidates: List[int] = [] + for pc in pages: + if pc.page_number in covered_pages: + continue + if self._page_has_table_density(pc.content): + candidates.append(pc.page_number - 1) # 0-indexed for kreuzberg + + return sorted(candidates)[:_FORCE_OCR_MAX_PAGES] + + # ------------------------------------------------------------------ # + # Selective force-OCR re-extraction (P2) # + # ------------------------------------------------------------------ # + + async def _selective_force_ocr_tables( + self, + file_path: str, + gap_pages: List[int], + ) -> list[dict[str, Any]]: + """Extract text from gap pages using pypdf (no kreuzberg re-call). + + Earlier versions spawned a second kreuzberg extraction with + ``force_ocr_pages``, which doubled native memory pressure. + Using pypdf instead avoids Rust/native allocations entirely + while still capturing page text for the table digest. + + Args: + file_path: Path to the PDF. + gap_pages: 0-indexed page numbers. + + Returns: + List of table-compatible dicts (``markdown``, ``page_number``). + """ + if not gap_pages: + return [] + + capped = sorted(gap_pages)[:_FORCE_OCR_MAX_PAGES] + one_indexed = [p + 1 for p in capped] + try: + pages = DocumentExtractor.extract_pages(file_path, one_indexed) + except Exception: + return [] + + tables: list[dict[str, Any]] = [] + for pc in pages: + text = (pc.content or "").strip() + if text and self._page_has_table_density(text): + tables.append({ + "markdown": text, + "cells": [], + "page_number": pc.page_number, + }) + return tables + + # ------------------------------------------------------------------ # + # Summary index for embedding + BM25 fallback # + # ------------------------------------------------------------------ # + + async def _build_summary_index(self, manifest: CompileManifest) -> None: + """Build summary embedding + BM25 index for fallback search. + + Creates a lightweight index mapping each compiled file to: + - Its summary text + - Pre-computed embedding vector (384-dim, if EmbeddingUtil available) + - Tokenized summary with term frequencies (via TokenizerUtil) + + The index is saved to .cache/compile/summary_index.json and consumed + by search.py as a last-resort fallback when rga keyword search fails. + + Reuses ``self._embedding`` when available to avoid loading a duplicate + model into memory. Falls back to a fresh instance otherwise. + """ + try: + from sirchmunk.utils.tokenizer_util import TokenizerUtil + from sirchmunk.learnings.summary_index import CompileSummaryIndex, SummaryIndexEntry + + entries: List[SummaryIndexEntry] = [] + summaries: List[str] = [] + + for file_path, entry in manifest.files.items(): + if entry.summary: + entries.append(SummaryIndexEntry( + file_path=file_path, + summary=entry.summary, + )) + summaries.append(entry.summary) + + if not entries: + return + + tokenizer = TokenizerUtil() + for idx, entry in enumerate(entries): + tokens = tokenizer.segment(entry.summary) + entry.tokens = tokens + entry.token_freqs = {} + for t in tokens: + entry.token_freqs[t] = entry.token_freqs.get(t, 0) + 1 + + # Reuse the compiler's embedding client to avoid duplicate model load + try: + embedding_util = self._embedding + if embedding_util is None: + from sirchmunk.utils.embedding_util import EmbeddingUtil + embedding_util = EmbeddingUtil() + embedding_util.start_loading() + + await embedding_util._ensure_model_async(timeout=60) + + if embedding_util.is_ready(): + embeddings = await embedding_util.embed(summaries) + for i, emb in enumerate(embeddings): + entries[i].embedding = emb + await self._log.info( + f"Summary index: computed embeddings for {len(entries)} entries" + ) + except Exception as emb_exc: + await self._log.warning( + f"Summary index: embedding computation skipped: {emb_exc}" + ) + + index = CompileSummaryIndex(entries) + index.save(self._compile_dir / "summary_index.json") + + except Exception as exc: + await self._log.warning(f"Failed to build summary index: {exc}") + + # ------------------------------------------------------------------ # + # Manifest I/O # + # ------------------------------------------------------------------ # + + def _load_manifest(self) -> CompileManifest: + if self._manifest_path.exists(): + try: + return CompileManifest.from_json( + self._manifest_path.read_text(encoding="utf-8") + ) + except Exception: + pass + return CompileManifest() + + def _save_manifest(self, manifest: CompileManifest) -> None: + """Atomically persist the manifest via write-to-tmp + rename. + + This prevents partial JSON on disk if the process is killed mid-write. + """ + tmp_path = self._manifest_path.with_suffix(".json.tmp") + tmp_path.write_text(manifest.to_json(), encoding="utf-8") + tmp_path.replace(self._manifest_path) + + # ------------------------------------------------------------------ # + # Document catalog for search-time routing # + # ------------------------------------------------------------------ # + + def _build_document_catalog(self, manifest: CompileManifest) -> None: + """Generate a lightweight catalog mapping files to their tree root summaries. + + The catalog is consumed by FAST search to fuse query analysis with + LLM-driven document routing in a single prompt. Each entry carries + the filename and a truncated root summary (<= _MANIFEST_SUMMARY_MAX_LEN chars). + + Summary is sourced from the manifest (populated during Phase 2 compile), + with a tree-root fallback for backward compatibility. + """ + tree_cache = self._compile_dir / "trees" + entries: List[Dict[str, str]] = [] + + for file_path, entry in manifest.files.items(): + summary = entry.summary # Primary: manifest-persisted summary + + # Fallback: read from tree root if manifest summary is empty + if not summary and entry.has_tree and tree_cache.exists(): + tree_file = tree_cache / f"{entry.file_hash}.json" + if tree_file.exists(): + try: + tree = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8"), + ) + if tree.root and tree.root.summary: + summary = tree.root.summary[:_MANIFEST_SUMMARY_MAX_LEN] + except Exception: + pass + + entries.append({ + "path": file_path, + "name": Path(file_path).name, + "summary": summary, + }) + + catalog_path = self._compile_dir / "document_catalog.json" + catalog_path.write_text( + json.dumps(entries, ensure_ascii=False, indent=2), + encoding="utf-8", + ) diff --git a/src/sirchmunk/learnings/knowledge_base.py b/src/sirchmunk/learnings/knowledge_base.py index 387b368..7296f71 100644 --- a/src/sirchmunk/learnings/knowledge_base.py +++ b/src/sirchmunk/learnings/knowledge_base.py @@ -120,11 +120,14 @@ async def _extract_evidence_for_file( confidence_threshold: float, top_k_snippets: int, verbose: bool, + tree_indexer=None, ) -> Optional[EvidenceUnit]: - """Extract evidence from a single file via Monte Carlo sampling. + """Extract evidence from a single file. - Performs text extraction followed by LLM-driven region-of-interest - identification. Designed to run concurrently for multiple files. + When a tree index exists for the file, uses LLM-driven tree navigation + to locate relevant sections precisely, then runs Monte Carlo sampling + within those narrowed regions. Falls back to full-document Monte + Carlo sampling otherwise. Args: file_path_or_url: Absolute path or URL to the document. @@ -133,6 +136,7 @@ async def _extract_evidence_for_file( confidence_threshold: Minimum confidence for evidence acceptance. top_k_snippets: Maximum evidence snippets per document. verbose: Whether to enable verbose logging. + tree_indexer: Optional DocumentTreeIndexer for tree-based navigation. Returns: EvidenceUnit on success, None on extraction failure. @@ -141,6 +145,28 @@ async def _extract_evidence_for_file( extraction_result = await fast_extract(file_path=file_path_or_url) doc_content: str = extraction_result.content + tree_path_ids = None + + # Try tree-based navigation for focused extraction + if tree_indexer is not None: + tree = tree_indexer.load_tree(file_path_or_url) + if tree is not None: + await self._log.info( + f"[KnowledgeBase] Using tree index for {Path(file_path_or_url).name}" + ) + leaves = await tree_indexer.navigate(tree, query) + if leaves: + # Narrow doc_content to matched regions + tree_path_ids = [n.node_id for n in leaves] + segments = [] + for node in leaves: + start, end = node.char_range + segment = doc_content[start:end] + if segment.strip(): + segments.append(segment) + if segments: + doc_content = "\n\n---\n\n".join(segments) + sampler = MonteCarloEvidenceSampling( llm=self.llm, doc_content=doc_content, @@ -162,6 +188,7 @@ async def _extract_evidence_for_file( snippets=roi_result.snippets, extracted_at=datetime.now(), conflict_group=[], + tree_path=tree_path_ids, ) self.llm_usages.extend(sampler.llm_usages) return evidence_unit @@ -181,6 +208,7 @@ async def build( top_k_snippets: Optional[int] = 5, confidence_threshold: Optional[float] = 8.0, verbose: bool = True, + tree_indexer=None, ) -> Union[KnowledgeCluster, None]: """Build a knowledge cluster from retrieved information and metadata. @@ -196,6 +224,8 @@ async def build( top_k_snippets: Max evidence snippets per file. confidence_threshold: Min confidence for evidence acceptance. verbose: Enable verbose logging. + tree_indexer: Optional DocumentTreeIndexer for tree-navigated + evidence extraction (uses compiled tree indices when available). Returns: KnowledgeCluster on success, None if no evidence found. @@ -223,6 +253,7 @@ async def build( confidence_threshold=confidence_threshold, top_k_snippets=top_k_snippets, verbose=verbose, + tree_indexer=tree_indexer, ) for info in retrieved_infos ] diff --git a/src/sirchmunk/learnings/lint.py b/src/sirchmunk/learnings/lint.py new file mode 100644 index 0000000..e5baa6f --- /dev/null +++ b/src/sirchmunk/learnings/lint.py @@ -0,0 +1,213 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Knowledge lint — health checks and auto-fixes for the knowledge network. + +Inspired by LLM Wiki's Lint operation: validates cluster integrity, +detects stale evidence, and cleans orphaned tree indices. +""" + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from sirchmunk.schema.knowledge import KnowledgeCluster, Lifecycle +from sirchmunk.storage.knowledge_storage import KnowledgeStorage +from sirchmunk.utils import LogCallback, create_logger + + +@dataclass +class LintIssue: + """A single lint finding.""" + + severity: str # "error", "warning", "info" + category: str # "stale_evidence", "orphan_tree", "empty_cluster", etc. + message: str + cluster_id: Optional[str] = None + file_path: Optional[str] = None + auto_fixed: bool = False + + def to_dict(self) -> Dict[str, Any]: + return { + "severity": self.severity, + "category": self.category, + "message": self.message, + "cluster_id": self.cluster_id, + "file_path": self.file_path, + "auto_fixed": self.auto_fixed, + } + + +@dataclass +class LintReport: + """Summary of a lint run.""" + + total_clusters_checked: int = 0 + total_trees_checked: int = 0 + issues: List[LintIssue] = field(default_factory=list) + auto_fixes_applied: int = 0 + + @property + def errors(self) -> int: + return sum(1 for i in self.issues if i.severity == "error") + + @property + def warnings(self) -> int: + return sum(1 for i in self.issues if i.severity == "warning") + + def to_dict(self) -> Dict[str, Any]: + return { + "total_clusters_checked": self.total_clusters_checked, + "total_trees_checked": self.total_trees_checked, + "errors": self.errors, + "warnings": self.warnings, + "auto_fixes_applied": self.auto_fixes_applied, + "issues": [i.to_dict() for i in self.issues], + } + + +class KnowledgeLint: + """Validate the health of the knowledge network and apply auto-fixes.""" + + def __init__( + self, + knowledge_storage: KnowledgeStorage, + work_path: Union[str, Path], + log_callback: LogCallback = None, + ): + self._storage = knowledge_storage + self._work_path = Path(work_path).expanduser().resolve() + self._tree_dir = self._work_path / ".cache" / "compile" / "trees" + self._manifest_path = self._work_path / ".cache" / "compile" / "manifest.json" + self._log = create_logger(log_callback=log_callback) + + async def run(self, *, auto_fix: bool = False) -> LintReport: + """Execute all lint checks and optionally apply auto-fixes.""" + report = LintReport() + + await self._log.info("[Lint] Starting knowledge health check") + + # Check clusters + await self._check_clusters(report, auto_fix=auto_fix) + + # Check orphaned tree caches + await self._check_orphan_trees(report, auto_fix=auto_fix) + + # Check manifest consistency + await self._check_manifest(report) + + await self._log.info( + f"[Lint] Done — clusters={report.total_clusters_checked}, " + f"trees={report.total_trees_checked}, " + f"errors={report.errors}, warnings={report.warnings}, " + f"fixes={report.auto_fixes_applied}" + ) + return report + + async def _check_clusters(self, report: LintReport, auto_fix: bool) -> None: + """Validate each knowledge cluster.""" + all_clusters = await self._storage.find("", limit=10000) + report.total_clusters_checked = len(all_clusters) + + for cluster in all_clusters: + # Check: empty content + if not cluster.content or ( + isinstance(cluster.content, str) and len(cluster.content.strip()) < 10 + ): + report.issues.append(LintIssue( + severity="warning", + category="empty_cluster", + message=f"Cluster has empty or minimal content", + cluster_id=cluster.id, + )) + + # Check: stale evidence (source files no longer exist) + stale_count = 0 + for ev in cluster.evidences: + fp = str(ev.file_or_url) + if fp.startswith("/") and not Path(fp).exists(): + stale_count += 1 + + if stale_count > 0: + report.issues.append(LintIssue( + severity="warning", + category="stale_evidence", + message=f"{stale_count} evidence source(s) no longer exist", + cluster_id=cluster.id, + )) + + if auto_fix and stale_count == len(cluster.evidences): + cluster.lifecycle = Lifecycle.DEPRECATED + await self._storage.update(cluster) + report.auto_fixes_applied += 1 + report.issues[-1].auto_fixed = True + + # Check: no queries and no evidences (orphan cluster) + if not cluster.evidences and not cluster.queries: + report.issues.append(LintIssue( + severity="info", + category="orphan_cluster", + message="Cluster has no evidence and no queries", + cluster_id=cluster.id, + )) + + # Check: isolated cluster (no WeakSemanticEdge connections) + if not cluster.related_clusters and cluster.evidences: + report.issues.append(LintIssue( + severity="info", + category="isolated_cluster", + message="Cluster has no cross-references to other clusters", + cluster_id=cluster.id, + )) + + async def _check_orphan_trees(self, report: LintReport, auto_fix: bool) -> None: + """Find tree cache files whose source documents no longer exist.""" + if not self._tree_dir.exists(): + return + + manifest = self._load_manifest() + # Build set of valid file hashes from the manifest + valid_hashes: Set[str] = set() + for entry_data in manifest.get("files", {}).values(): + fh = entry_data.get("file_hash", "") + if fh: + valid_hashes.add(fh) + + tree_files = list(self._tree_dir.glob("*.json")) + report.total_trees_checked = len(tree_files) + + for tf in tree_files: + tree_hash = tf.stem + if tree_hash not in valid_hashes: + report.issues.append(LintIssue( + severity="info", + category="orphan_tree", + message=f"Tree cache has no matching manifest entry", + file_path=str(tf), + )) + if auto_fix: + tf.unlink(missing_ok=True) + report.auto_fixes_applied += 1 + report.issues[-1].auto_fixed = True + + async def _check_manifest(self, report: LintReport) -> None: + """Validate manifest references.""" + manifest = self._load_manifest() + files = manifest.get("files", {}) + + for fp, entry_data in files.items(): + if not Path(fp).exists(): + report.issues.append(LintIssue( + severity="warning", + category="stale_manifest", + message=f"Manifest references non-existent file", + file_path=fp, + )) + + def _load_manifest(self) -> Dict[str, Any]: + if self._manifest_path.exists(): + try: + return json.loads(self._manifest_path.read_text(encoding="utf-8")) + except Exception: + pass + return {} diff --git a/src/sirchmunk/learnings/summary_index.py b/src/sirchmunk/learnings/summary_index.py new file mode 100644 index 0000000..7ec355a --- /dev/null +++ b/src/sirchmunk/learnings/summary_index.py @@ -0,0 +1,255 @@ +"""Compile-time summary index for embedding + BM25 fallback retrieval. + +This module provides a lightweight, file-level index that combines: +- Semantic similarity via pre-computed embeddings (384-dim MiniLM) +- Lexical matching via BM25 scoring (TokenizerUtil segmentation) + +Used ONLY as a fallback when rga keyword search returns zero results. +""" + +import json +import math +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +@dataclass +class SummaryIndexEntry: + """Single file entry in the summary index.""" + file_path: str + summary: str + embedding: Optional[List[float]] = None # 384-dim, pre-normalized + tokens: Optional[List[str]] = None # TokenizerUtil.segment() output + token_freqs: Optional[Dict[str, int]] = None # pre-computed term frequencies + + +class CompileSummaryIndex: + """Pre-computed summary index for hybrid embedding + BM25 fallback search. + + This index is built at compile time and loaded at search time. + It provides a fallback retrieval mechanism when rga keyword search + returns zero results, combining semantic similarity (embedding cosine) + with lexical matching (BM25). + + The fusion algorithm uses Sigmoid Z-Score normalization: + 1. Compute raw scores from both channels + 2. Z-Score normalize each channel independently + 3. Weighted combination: alpha * z_embedding + (1-alpha) * z_bm25 + 4. Sigmoid activation for final score + """ + + # BM25 parameters (Okapi BM25 standard defaults) + _BM25_K1: float = 1.5 + _BM25_B: float = 0.75 + + # Fusion parameters + _DEFAULT_ALPHA: float = 0.5 # embedding weight; (1-alpha) = BM25 weight + + # Z-Score fallback for missing channel + _MISSING_CHANNEL_Z: float = -3.0 # ~0.1 percentile + + def __init__(self, entries: List[SummaryIndexEntry]) -> None: + self._entries = entries + self._num_docs = len(entries) + self._avg_doc_len = self._compute_avg_doc_len() + self._doc_freqs: Dict[str, int] = self._compute_doc_freqs() + + def _compute_avg_doc_len(self) -> float: + """Compute average document length (in tokens) across all entries.""" + lengths = [len(e.tokens or []) for e in self._entries] + return sum(lengths) / max(1, len(lengths)) + + def _compute_doc_freqs(self) -> Dict[str, int]: + """Compute document frequency for each unique token.""" + df: Dict[str, int] = {} + for entry in self._entries: + if entry.token_freqs: + for token in entry.token_freqs: + df[token] = df.get(token, 0) + 1 + return df + + @classmethod + def load(cls, index_path: Path) -> Optional["CompileSummaryIndex"]: + """Load index from JSON file. Returns None on failure.""" + try: + if not index_path.exists(): + return None + data = json.loads(index_path.read_text(encoding="utf-8")) + entries = [] + for item in data.get("entries", []): + entries.append(SummaryIndexEntry( + file_path=item["file_path"], + summary=item.get("summary", ""), + embedding=item.get("embedding"), + tokens=item.get("tokens"), + token_freqs=item.get("token_freqs"), + )) + if not entries: + return None + return cls(entries) + except Exception as exc: + logger.warning("Failed to load summary index from %s: %s", index_path, exc) + return None + + def save(self, index_path: Path) -> None: + """Persist index to JSON file.""" + index_path.parent.mkdir(parents=True, exist_ok=True) + data = { + "version": 1, + "num_entries": len(self._entries), + "entries": [ + { + "file_path": e.file_path, + "summary": e.summary, + "embedding": e.embedding, + "tokens": e.tokens, + "token_freqs": e.token_freqs, + } + for e in self._entries + ], + } + index_path.write_text( + json.dumps(data, ensure_ascii=False), + encoding="utf-8", + ) + logger.info("Summary index saved: %d entries -> %s", len(self._entries), index_path) + + def search( + self, + query_embedding: Optional[List[float]], + query_tokens: List[str], + top_k: int = 5, + alpha: float = _DEFAULT_ALPHA, + ) -> List[Tuple[str, float]]: + """Hybrid search combining embedding cosine similarity and BM25. + + Uses Sigmoid Z-Score fusion: + 1. Compute raw embedding cosine sim and BM25 score per document + 2. Z-Score normalize each channel + 3. Weighted linear combination + 4. Sigmoid activation + + Args: + query_embedding: 384-dim query vector (None to use BM25 only). + query_tokens: Tokenized query from TokenizerUtil.segment(). + top_k: Maximum number of results. + alpha: Embedding weight in [0, 1]. BM25 weight = 1 - alpha. + + Returns: + List of (file_path, fusion_score) sorted descending by score. + """ + if not self._entries: + return [] + + # Compute raw scores + emb_scores: List[Optional[float]] = [] + bm25_scores: List[float] = [] + + has_embedding = query_embedding is not None + + for entry in self._entries: + # Embedding channel + if has_embedding and entry.embedding: + emb_scores.append(self._cosine_similarity(query_embedding, entry.embedding)) + else: + emb_scores.append(None) + + # BM25 channel + bm25_scores.append(self._bm25_score(query_tokens, entry)) + + # Z-Score normalization + z_emb = self._z_score_normalize(emb_scores) + z_bm25 = self._z_score_normalize(bm25_scores) + + # Sigmoid fusion + results: List[Tuple[str, float]] = [] + for i, entry in enumerate(self._entries): + z_e = z_emb[i] if z_emb[i] is not None else self._MISSING_CHANNEL_Z + z_b = z_bm25[i] if z_bm25[i] is not None else self._MISSING_CHANNEL_Z + + combined = alpha * z_e + (1.0 - alpha) * z_b + score = 1.0 / (1.0 + math.exp(-combined)) + results.append((entry.file_path, score)) + + # Sort descending and return top_k + results.sort(key=lambda x: x[1], reverse=True) + return results[:top_k] + + def _bm25_score(self, query_tokens: List[str], entry: SummaryIndexEntry) -> float: + """Compute BM25 score for a single document. + + Uses standard Okapi BM25 formula: + score = sum over query terms: + IDF(t) * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * dl / avgdl)) + """ + if not query_tokens or not entry.token_freqs: + return 0.0 + + dl = len(entry.tokens or []) + score = 0.0 + + for token in query_tokens: + tf = entry.token_freqs.get(token, 0) + if tf == 0: + continue + + # IDF: log((N - df + 0.5) / (df + 0.5) + 1) + df = self._doc_freqs.get(token, 0) + idf = math.log((self._num_docs - df + 0.5) / (df + 0.5) + 1.0) + + # TF component + tf_component = (tf * (self._BM25_K1 + 1.0)) / ( + tf + self._BM25_K1 * (1.0 - self._BM25_B + self._BM25_B * dl / max(1.0, self._avg_doc_len)) + ) + + score += idf * tf_component + + return score + + @staticmethod + def _cosine_similarity(a: List[float], b: List[float]) -> float: + """Compute cosine similarity between two vectors. + + When embeddings are pre-normalized (L2 norm = 1), this reduces + to a simple dot product. + """ + if len(a) != len(b): + return 0.0 + dot = sum(x * y for x, y in zip(a, b)) + # Clamp to [-1, 1] for numerical safety + return max(-1.0, min(1.0, dot)) + + @staticmethod + def _z_score_normalize(scores: List[Optional[float]]) -> List[Optional[float]]: + """Z-Score normalize a list of scores, preserving None entries. + + None entries remain None (handled as _MISSING_CHANNEL_Z at fusion). + """ + valid = [s for s in scores if s is not None] + if len(valid) < 2: + # Not enough data points for meaningful normalization + return scores + + mean = sum(valid) / len(valid) + variance = sum((s - mean) ** 2 for s in valid) / len(valid) + std = math.sqrt(variance) if variance > 0 else 1.0 + + if std < 1e-9: + # All scores identical — return zeros + return [0.0 if s is not None else None for s in scores] + + return [(s - mean) / std if s is not None else None for s in scores] + + @property + def num_entries(self) -> int: + """Number of indexed documents.""" + return self._num_docs + + @property + def has_embeddings(self) -> bool: + """Whether any entry has a pre-computed embedding.""" + return any(e.embedding is not None for e in self._entries) diff --git a/src/sirchmunk/learnings/toc_extractor.py b/src/sirchmunk/learnings/toc_extractor.py new file mode 100644 index 0000000..0197485 --- /dev/null +++ b/src/sirchmunk/learnings/toc_extractor.py @@ -0,0 +1,964 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +TOC (Table of Contents) extractor — multi-layer fallback strategy. + +Extracts hierarchical table-of-contents structures from various document +formats (PDF, Markdown, DOCX, HTML) using a layered approach: + + Layer 1 — pypdf native outline (highest confidence, zero cost) + Layer 2 — pdfminer.six detailed parsing (fallback for pypdf) + Layer 3 — Text heading pattern detection (for documents without bookmarks) + Layer 4 — LLM-assisted inference (optional, last resort) + +The extracted TOCEntry list is consumed by the tree indexer to accelerate +tree construction. +""" + +import json +import logging +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, ClassVar, List, Optional + +logger = logging.getLogger(__name__) + +# Known heading-style prefixes across locales (English, Chinese, etc.) +_HEADING_STYLE_PREFIXES = ("Heading", "heading", "\u6807\u9898") # "标题" = Chinese + + +# --------------------------------------------------------------------------- +# Data models +# --------------------------------------------------------------------------- + + +@dataclass +class TOCEntry: + """Single entry in an extracted table of contents. + + Attributes: + title: Section title text. + level: Heading depth (1 = top-level section, 2 = subsection, …). + char_start: Character offset in the extracted full text. + char_end: End character offset (exclusive), or None if unresolved. + page_start: 1-indexed page number, or None if unknown. + page_end: End page number (inclusive), or None. + children: Nested sub-entries forming a tree. + source: Which extraction layer produced this entry + ("pypdf", "pdfminer", "heading", "markdown", "docx", + "html", "llm"). + """ + + title: str + level: int # 1=section, 2=subsection, … + char_start: int = 0 + char_end: Optional[int] = None + page_start: Optional[int] = None + page_end: Optional[int] = None + children: List["TOCEntry"] = field(default_factory=list) + source: str = "" + + +@dataclass +class TocResult: + """Complete TOC extraction result with quality metadata. + + Attributes: + entries: Ordered list of TOCEntry objects. + source: Primary extraction method that produced the result. + confidence: Estimated quality score (0.0–1.0). + page_count: Total pages in the source document, if known. + """ + + entries: List[TOCEntry] = field(default_factory=list) + source: str = "" + confidence: float = 0.0 + page_count: Optional[int] = None + + +# --------------------------------------------------------------------------- +# Layer 1: pypdf native outline +# --------------------------------------------------------------------------- + + +class PypdfOutlineExtractor: + """Layer 1: Extract TOC from PDF native outline/bookmarks using pypdf. + + Highest confidence (0.9) — relies on the PDF producer embedding + explicit bookmarks. Zero external cost. + """ + + @staticmethod + def extract(file_path: str | Path) -> TocResult: + """Extract TOC from PDF outline. + + Args: + file_path: Path to the PDF file. + + Returns: + TocResult with entries and page_count populated, + or an empty TocResult on failure. + """ + try: + from pypdf import PdfReader + + reader = PdfReader(str(file_path)) + outline = reader.outline + page_count = len(reader.pages) + + if not outline: + return TocResult(source="pypdf", page_count=page_count) + + entries: List[TOCEntry] = [] + PypdfOutlineExtractor._parse_outline( + reader, outline, entries, level=1, + ) + + if not entries: + return TocResult(source="pypdf", page_count=page_count) + + return TocResult( + entries=entries, + source="pypdf", + confidence=0.9, + page_count=page_count, + ) + except Exception as exc: + logger.debug("pypdf outline extraction failed: %s", exc) + return TocResult(source="pypdf") + + @staticmethod + def _parse_outline( + reader: Any, + outline_items: list, + entries: List[TOCEntry], + level: int, + ) -> None: + """Recursively parse pypdf outline items into TOCEntry list.""" + for item in outline_items: + if isinstance(item, list): + # Nested list → sub-bookmarks; attach to last entry + if entries: + sub: List[TOCEntry] = [] + PypdfOutlineExtractor._parse_outline( + reader, item, sub, level=level + 1, + ) + entries[-1].children.extend(sub) + else: + PypdfOutlineExtractor._parse_outline( + reader, item, entries, level=level, + ) + else: + try: + title = item.title if hasattr(item, "title") else str(item) + page_num: Optional[int] = None + try: + # get_destination_page_number returns 0-indexed + raw = reader.get_destination_page_number(item) + if raw is not None: + page_num = raw + 1 # convert to 1-indexed + except Exception: + pass + entries.append(TOCEntry( + title=title.strip(), + level=level, + char_start=0, + page_start=page_num, + source="pypdf", + )) + except Exception: + continue + + +# --------------------------------------------------------------------------- +# Layer 2: pdfminer.six detailed parsing +# --------------------------------------------------------------------------- + + +class PdfminerOutlineExtractor: + """Layer 2: Extract TOC using pdfminer.six for more detailed parsing. + + Falls back here when pypdf yields insufficient entries. + Confidence 0.85 — pdfminer exposes more detail but requires + manual page-number resolution. + """ + + @staticmethod + def extract(file_path: str | Path) -> TocResult: + """Extract TOC using pdfminer's outline parser. + + Args: + file_path: Path to the PDF file. + + Returns: + TocResult with entries populated, or empty on failure. + """ + try: + from pdfminer.pdfdocument import PDFDocument, PDFNoOutlines + from pdfminer.pdfpage import PDFPage + from pdfminer.pdfparser import PDFParser + from pdfminer.psparser import LIT + + fp = open(str(file_path), "rb") + try: + parser = PDFParser(fp) + document = PDFDocument(parser) + + # Build page-object-id → 1-indexed page number mapping + pages = list(PDFPage.create_pages(document)) + page_count = len(pages) + objid_to_pagenum = { + page.pageid: idx + 1 + for idx, page in enumerate(pages) + } + + entries: List[TOCEntry] = [] + try: + for level, title, dest, action, _se in document.get_outlines(): + page_num = PdfminerOutlineExtractor._resolve_page( + dest, action, objid_to_pagenum, document, + ) + entries.append(TOCEntry( + title=str(title).strip() if title else "", + level=level, + char_start=0, + page_start=page_num, + source="pdfminer", + )) + except PDFNoOutlines: + pass + + if not entries: + return TocResult(source="pdfminer", page_count=page_count) + + return TocResult( + entries=entries, + source="pdfminer", + confidence=0.85, + page_count=page_count, + ) + finally: + fp.close() + except Exception as exc: + logger.debug("pdfminer outline extraction failed: %s", exc) + return TocResult(source="pdfminer") + + @staticmethod + def _resolve_page( + dest: Any, + action: Any, + objid_to_pagenum: dict, + document: Any, + ) -> Optional[int]: + """Resolve a pdfminer destination/action to a 1-indexed page number.""" + try: + from pdfminer.pdfparser import PDFStream + from pdfminer.pdftypes import resolve1 + + # Try dest first + target = dest + if target is None and action is not None: + # GoTo action: action dict may have a 'D' key + if isinstance(action, dict): + target = action.get("D") + + if target is None: + return None + + # Resolve indirect objects + target = resolve1(target) + + if isinstance(target, list) and len(target) > 0: + page_ref = resolve1(target[0]) + if hasattr(page_ref, "objid"): + return objid_to_pagenum.get(page_ref.objid) + elif hasattr(target, "objid"): + return objid_to_pagenum.get(target.objid) + except Exception: + pass + return None + + +# --------------------------------------------------------------------------- +# Layer 3: Text heading pattern detection +# --------------------------------------------------------------------------- + + +class HeadingTocExtractor: + """Layer 3: Infer TOC from document text structure (heading patterns). + + Handles Markdown headings, numbered sections, and common structural + keywords. Confidence 0.6 — heuristic-based, lower precision. + """ + + # Regex for Markdown ATX headings: # Title, ## Subtitle, … + _MD_HEADING_RE: ClassVar[re.Pattern] = re.compile( + r"^(#{1,6})\s+(.+)$", re.MULTILINE, + ) + + # Regex for numbered section patterns: "1.", "1.1", "1.1.1", … + _NUMBERED_RE: ClassVar[re.Pattern] = re.compile( + r"^(\d+(?:\.\d+)*)[.\s]+(.+)$", re.MULTILINE, + ) + + # Common structural keywords (case-insensitive) + _STRUCTURAL_KEYWORDS: ClassVar[tuple] = ( + "ITEM", "PART", "CHAPTER", "SECTION", "ARTICLE", + "APPENDIX", "EXHIBIT", "SCHEDULE", "ANNEX", + ) + + # Max characters for a candidate heading line + _MAX_HEADING_LINE_LEN: ClassVar[int] = 120 + + @staticmethod + def extract(content: str, mime_type: str = "") -> TocResult: + """Infer TOC from text content by detecting heading patterns. + + Tries strategies in order: + 1. Markdown ATX headings (``#`` syntax) + 2. Numbered section patterns (``1.``, ``1.1``, …) + 3. Structural keyword detection (ITEM, PART, CHAPTER, …) + + Args: + content: Full extracted text of the document. + mime_type: Optional MIME type hint (unused currently). + + Returns: + TocResult with char_position-based entries. + """ + if not content or len(content.strip()) < 50: + return TocResult(source="heading") + + # Strategy 1: Markdown headings + entries = HeadingTocExtractor._extract_markdown_headings(content) + if entries: + return TocResult( + entries=entries, + source="heading", + confidence=0.7, + ) + + # Strategy 2: Numbered sections + entries = HeadingTocExtractor._extract_numbered_sections(content) + if entries: + return TocResult( + entries=entries, + source="heading", + confidence=0.6, + ) + + # Strategy 3: Structural keywords + heuristic + entries = HeadingTocExtractor._extract_structural_headings(content) + if entries: + return TocResult( + entries=entries, + source="heading", + confidence=0.5, + ) + + return TocResult(source="heading") + + @staticmethod + def _extract_markdown_headings(content: str) -> List[TOCEntry]: + """Extract headings from Markdown ATX syntax (# / ## / ###).""" + matches = list(HeadingTocExtractor._MD_HEADING_RE.finditer(content)) + if not matches: + return [] + + entries: List[TOCEntry] = [] + for m in matches: + hashes, title = m.group(1), m.group(2).strip() + if title: + entries.append(TOCEntry( + title=title, + level=len(hashes), + char_start=m.start(), + source="heading", + )) + return entries + + @staticmethod + def _extract_numbered_sections(content: str) -> List[TOCEntry]: + """Extract numbered section headings (1., 1.1, 1.1.1, …).""" + matches = list(HeadingTocExtractor._NUMBERED_RE.finditer(content)) + if not matches: + return [] + + entries: List[TOCEntry] = [] + for m in matches: + number_part = m.group(1) + title_part = m.group(2).strip() + # Line length check — skip long lines (likely not headings) + line_len = m.end() - m.start() + if line_len > HeadingTocExtractor._MAX_HEADING_LINE_LEN: + continue + if not title_part: + continue + level = number_part.count(".") + 1 + entries.append(TOCEntry( + title=f"{number_part} {title_part}", + level=level, + char_start=m.start(), + source="heading", + )) + return entries + + @staticmethod + def _extract_structural_headings(content: str) -> List[TOCEntry]: + """Detect common structural keywords as section boundaries.""" + # Build pattern: ITEM 1, PART I, CHAPTER 1, etc. + kw_pattern = "|".join(HeadingTocExtractor._STRUCTURAL_KEYWORDS) + pattern = re.compile( + rf"^({kw_pattern})\s+(\w+[\w .:\-]*)$", + re.MULTILINE | re.IGNORECASE, + ) + matches = list(pattern.finditer(content)) + if not matches: + return [] + + entries: List[TOCEntry] = [] + for m in matches: + keyword = m.group(1).upper() + rest = m.group(2).strip() + title = f"{keyword} {rest}" + # Determine level based on keyword + if keyword in ("PART", "CHAPTER"): + level = 1 + elif keyword in ("ITEM", "SECTION", "ARTICLE"): + level = 2 + else: + level = 3 + entries.append(TOCEntry( + title=title, + level=level, + char_start=m.start(), + source="heading", + )) + return entries + + +# --------------------------------------------------------------------------- +# Layer 4: LLM-assisted inference (optional) +# --------------------------------------------------------------------------- + + +class LlmTocExtractor: + """Layer 4: Use LLM to infer TOC from document content. + + This is the last-resort fallback. Requires an ``llm_caller`` that + supports ``await llm_caller.achat(messages)``. If no caller is + provided, returns an empty result immediately. + + Confidence 0.7 — LLM may hallucinate structure. + """ + + # Maximum characters sent to the LLM to stay within token limits + _MAX_CONTENT_CHARS: ClassVar[int] = 8_000 + + _PROMPT_TEMPLATE: ClassVar[str] = ( + "Analyze the following document excerpt and extract its " + "hierarchical table of contents (TOC) structure.\n\n" + "Return a JSON array where each element has:\n" + ' - "title": section title text\n' + ' - "level": integer heading depth (1=top, 2=sub, 3=subsub)\n\n' + "Only include actual section/chapter headings, not every paragraph.\n" + "Return ONLY the JSON array, no other text.\n\n" + "Document excerpt:\n---\n{content}\n---" + ) + + @staticmethod + async def extract( + content: str, + llm_caller: Any | None = None, + ) -> TocResult: + """Infer TOC using LLM analysis. + + Args: + content: Full extracted text of the document. + llm_caller: An object with ``achat(messages)`` method. + If None, returns an empty result. + + Returns: + TocResult with LLM-inferred entries. + """ + if llm_caller is None: + return TocResult(source="llm") + + if not content or len(content.strip()) < 100: + return TocResult(source="llm") + + try: + # Truncate content to fit token budget + truncated = content[:LlmTocExtractor._MAX_CONTENT_CHARS] + prompt = LlmTocExtractor._PROMPT_TEMPLATE.format(content=truncated) + + resp = await llm_caller.achat([{"role": "user", "content": prompt}]) + raw = resp.content.strip() + + entries = LlmTocExtractor._parse_response(raw, content) + if not entries: + return TocResult(source="llm") + + return TocResult( + entries=entries, + source="llm", + confidence=0.7, + ) + except Exception as exc: + logger.debug("LLM TOC extraction failed: %s", exc) + return TocResult(source="llm") + + @staticmethod + def _parse_response(raw: str, content: str) -> List[TOCEntry]: + """Parse LLM JSON response into TOCEntry list with char_positions.""" + # Strip markdown code fences if present + cleaned = raw.strip() + if cleaned.startswith("```"): + lines = cleaned.split("\n") + # Remove first and last fence lines + lines = [l for l in lines if not l.strip().startswith("```")] + cleaned = "\n".join(lines) + + try: + items = json.loads(cleaned) + except (json.JSONDecodeError, TypeError): + return [] + + if not isinstance(items, list): + return [] + + content_lower = content.lower() + search_from = 0 + entries: List[TOCEntry] = [] + + for item in items: + if not isinstance(item, dict): + continue + title = str(item.get("title", "")).strip() + level = int(item.get("level", 1)) + if not title: + continue + + # Try to locate title in content for char_position + pos = content_lower.find(title.lower(), search_from) + if pos >= 0: + char_start = pos + search_from = pos + len(title) + else: + # Fallback: try from beginning + pos = content_lower.find(title.lower()) + char_start = pos if pos >= 0 else search_from + + entries.append(TOCEntry( + title=title, + level=max(1, min(level, 6)), + char_start=char_start, + source="llm", + )) + + return entries + + +# --------------------------------------------------------------------------- +# Format-specific extractors (non-PDF) +# --------------------------------------------------------------------------- + + +class DocxTocExtractor: + """Extract TOC from DOCX heading styles using python-docx.""" + + @staticmethod + def extract(file_path: str | Path) -> TocResult: + """Extract TOC from DOCX heading styles. + + Args: + file_path: Path to the DOCX file. + + Returns: + TocResult with entries from heading styles. + """ + try: + import docx + + doc = docx.Document(str(file_path)) + entries: List[TOCEntry] = [] + for para in doc.paragraphs: + style_name = para.style.name or "" + matched_prefix = "" + for prefix in _HEADING_STYLE_PREFIXES: + if style_name.startswith(prefix): + matched_prefix = prefix + break + if not matched_prefix: + continue + level_str = style_name[len(matched_prefix):].strip() + try: + level = int(level_str) if level_str else 1 + except ValueError: + level = 1 + title = para.text.strip() + if title: + entries.append(TOCEntry( + title=title, + level=level, + char_start=0, + source="docx", + )) + + if not entries: + return TocResult(source="docx") + return TocResult(entries=entries, source="docx", confidence=0.85) + except Exception as exc: + logger.debug("DOCX TOC extraction failed: %s", exc) + return TocResult(source="docx") + + +class HtmlTocExtractor: + """Extract TOC from HTML heading tags (

).""" + + _HTML_HEADING_RE: ClassVar[re.Pattern] = re.compile( + r"]*>(.*?)", + re.IGNORECASE | re.DOTALL, + ) + + @staticmethod + def extract(content: str) -> TocResult: + """Extract TOC from HTML heading tags. + + Args: + content: HTML text content. + + Returns: + TocResult with entries from

tags. + """ + try: + matches = HtmlTocExtractor._HTML_HEADING_RE.findall(content) + if not matches: + return TocResult(source="html") + + entries: List[TOCEntry] = [] + for level_str, raw_title in matches: + title = re.sub(r"<[^>]+>", "", raw_title).strip() + if title: + entries.append(TOCEntry( + title=title, + level=int(level_str), + char_start=0, + source="html", + )) + + if not entries: + return TocResult(source="html") + return TocResult(entries=entries, source="html", confidence=0.8) + except Exception as exc: + logger.debug("HTML TOC extraction failed: %s", exc) + return TocResult(source="html") + + +# --------------------------------------------------------------------------- +# Orchestrator: multi-layer fallback +# --------------------------------------------------------------------------- + + +class TOCExtractor: + """Orchestrates multi-layer TOC extraction with fallback strategy. + + All methods are static/classmethod — no instance state required. + The main ``extract()`` entry point dispatches by file extension and + applies the layered fallback for PDF files. + + Layer priority for PDFs: + 1. pypdf native outline (confidence 0.9) + 2. pdfminer.six detailed parsing (confidence 0.85) + 3. Text heading detection (confidence 0.5–0.7) + 4. LLM-assisted inference (confidence 0.7, optional) + + Design constraints: + - Layers 1–3 are pure-local, zero LLM calls + - Layer 4 is optional (requires llm_caller) + - Each layer is independently try-excepted; failure never blocks + subsequent layers + """ + + # Minimum entries to consider a TOC extraction successful + _MIN_ENTRIES_THRESHOLD: ClassVar[int] = 3 + + @staticmethod + def _build_hierarchy(flat_entries: List["TOCEntry"]) -> List["TOCEntry"]: + """Convert flat TocEntry list to nested tree using level field. + + Uses stack-based algorithm, O(n). When encountering a deeper level + entry, push it as a child of the current stack top; when same or + shallower, pop back to the corresponding level. + + Args: + flat_entries: Flat list of TOCEntry objects with ``level`` set. + + Returns: + List of top-level TOCEntry objects with ``children`` populated. + """ + if not flat_entries: + return [] + + roots: List[TOCEntry] = [] + # Stack holds (level, entry) pairs representing the current path + stack: List[TOCEntry] = [] + + for entry in flat_entries: + # Reset children to avoid stale data from prior processing + entry.children = [] + + # Pop stack until we find the parent (shallower level) + while stack and stack[-1].level >= entry.level: + stack.pop() + + if stack: + # Attach as child of the current stack top + stack[-1].children.append(entry) + else: + # No parent — this is a root-level entry + roots.append(entry) + + stack.append(entry) + + return roots + + @classmethod + async def extract( + cls, + file_path: str, + content: str, + *, + llm_caller: Any | None = None, + total_pages: Optional[int] = None, + ) -> Optional[List[TOCEntry]]: + """Extract TOC using layered fallback strategy. + + Tries extraction methods in order of reliability. Falls back to + the next layer when the current layer yields fewer than + ``_MIN_ENTRIES_THRESHOLD`` entries. + + Args: + file_path: Absolute path to the source file. + content: Extracted text content of the file. + llm_caller: Optional LLM caller for Layer 4. + total_pages: Total page count of the source document, if known. + Used to estimate ``page_start`` for Layer 3/4 entries. + + Returns: + List of TOCEntry with resolved char positions, or None if + no layer produced enough entries. + """ + ext = Path(file_path).suffix.lower() + + result: Optional[TocResult] = None + # Track whether the result came from pypdf (Layer 1) which + # already produces a properly nested tree with children. + is_pypdf = False + + if ext == ".pdf": + result = await cls._extract_pdf_layered( + file_path, content, llm_caller, + ) + if result is not None: + is_pypdf = result.source == "pypdf" + elif ext in (".md", ".markdown"): + heading_result = HeadingTocExtractor.extract(content) + if cls._is_sufficient(heading_result): + result = heading_result + elif ext in (".docx",): + result = DocxTocExtractor.extract(file_path) + elif ext in (".html", ".htm"): + result = HtmlTocExtractor.extract(content) + else: + return None + + if result is None or not cls._is_sufficient(result): + return None + + # Merge total_pages from TocResult if not explicitly provided + if total_pages is None and result.page_count: + total_pages = result.page_count + + entries = result.entries + + # Post-processing for non-pypdf layers: rebuild hierarchy from + # flat level-annotated entries (Layer 2/3/4 and format extractors + # produce flat lists; pypdf already builds a nested tree). + if not is_pypdf: + entries = cls._build_hierarchy(entries) + + # Estimate page_start for Layer 3/4 entries that lack it + if total_pages and content: + flat_all: List[TOCEntry] = [] + cls._flatten_entries(entries, flat_all) + content_len = len(content) + for entry in flat_all: + if entry.page_start is None and entry.char_start is not None: + entry.page_start = min( + total_pages, + max(1, round(entry.char_start / content_len * total_pages) + 1), + ) + + total = cls._count_entries(entries) + if total < cls._MIN_ENTRIES_THRESHOLD: + return None + + # Resolve character positions in the extracted text + entries = cls._resolve_char_positions(entries, content) + return entries + + @classmethod + async def _extract_pdf_layered( + cls, + file_path: str, + content: str, + llm_caller: Any | None, + ) -> Optional[TocResult]: + """Apply layered extraction for PDF files. + + Args: + file_path: Path to the PDF file. + content: Extracted text content. + llm_caller: Optional LLM caller for Layer 4. + + Returns: + Best TocResult from the layer cascade, or None. + """ + # Layer 1: pypdf + result = PypdfOutlineExtractor.extract(file_path) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 1 (pypdf): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + # Layer 2: pdfminer.six + result = PdfminerOutlineExtractor.extract(file_path) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 2 (pdfminer): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + # Layer 3: heading detection from content + if content: + result = HeadingTocExtractor.extract(content) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 3 (heading): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + # Layer 4: LLM-assisted (optional) + if llm_caller is not None and content: + result = await LlmTocExtractor.extract(content, llm_caller) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 4 (LLM): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + logger.debug( + "TOC extraction: no layer produced sufficient entries for %s", + Path(file_path).name, + ) + return None + + @classmethod + def _is_sufficient(cls, result: Optional[TocResult]) -> bool: + """Check whether a TocResult has enough entries to be useful.""" + if result is None: + return False + return len(result.entries) >= cls._MIN_ENTRIES_THRESHOLD + + # ------------------------------------------------------------------ # + # Character position resolution # + # ------------------------------------------------------------------ # + + @staticmethod + def _resolve_char_positions( + entries: List[TOCEntry], + content: str, + ) -> List[TOCEntry]: + """Resolve character start/end positions for TOC entries in content. + + Searches for each entry's title in the content text using + case-insensitive matching, progressing forward to avoid duplicate + matches. Sets char_end to the start of the next entry (or + len(content) for the last entry). + + Also recurses into children to resolve their positions. + + Args: + entries: Flat list of TOCEntry to resolve. + content: Full extracted text to search within. + + Returns: + The same list with char_start and char_end populated. + """ + if not content or not entries: + return entries + + content_lower = content.lower() + search_from = 0 + + # Collect all entries in document order (top-level + children) + flat: List[TOCEntry] = [] + TOCExtractor._flatten_entries(entries, flat) + + # Pass 1: resolve char_start for each entry + for entry in flat: + title_lower = entry.title.lower().strip() + if not title_lower: + entry.char_start = search_from + continue + # Normalise whitespace for fuzzy matching + title_normalised = re.sub(r"\s+", " ", title_lower) + pos = content_lower.find(title_normalised, search_from) + if pos < 0: + pos = content_lower.find(title_lower, search_from) + if pos >= 0: + entry.char_start = pos + search_from = pos + len(title_lower) + else: + pos = content_lower.find(title_normalised) + if pos < 0: + pos = content_lower.find(title_lower) + if pos >= 0: + entry.char_start = pos + else: + entry.char_start = search_from + + # Pass 2: resolve char_end as start of next entry (or len(content)) + for i in range(len(flat) - 1): + flat[i].char_end = flat[i + 1].char_start + if flat: + flat[-1].char_end = len(content) + + return entries + + @staticmethod + def _flatten_entries( + entries: List[TOCEntry], + flat: List[TOCEntry], + ) -> None: + """Flatten nested TOCEntry tree into document-order list.""" + for entry in entries: + flat.append(entry) + if entry.children: + TOCExtractor._flatten_entries(entry.children, flat) + + @staticmethod + def _count_entries(entries: List[TOCEntry]) -> int: + """Count total entries including nested children.""" + count = 0 + for entry in entries: + count += 1 + if entry.children: + count += TOCExtractor._count_entries(entry.children) + return count diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py new file mode 100644 index 0000000..10cab2b --- /dev/null +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -0,0 +1,1792 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Document tree indexer — PageIndex-inspired hierarchical structure analysis. + +Builds a JSON tree index for structured long documents (PDF, DOCX, MD, HTML) +so that downstream search can navigate via LLM reasoning instead of brute-force +Monte Carlo sampling. +""" + +import json +import math +import os +import re +from collections import Counter +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from sirchmunk.llm.openai_chat import OpenAIChat +from sirchmunk.utils import LogCallback, create_logger +from sirchmunk.utils.file_utils import get_fast_hash + +# File-size threshold: skip tree indexing for small files +_TREE_MIN_CHARS = 10_000 # 10 K characters (lowered from 20K for broader coverage) + +# Adaptive depth thresholds: (min_chars, max_depth) — evaluated top-down; +# **must** be sorted by min_chars descending so the first match wins. +_TREE_ADAPTIVE_DEPTH_THRESHOLDS: tuple = ( + (100_000, 4), + (50_000, 3), + (20_000, 2), +) + +# Summary snippet length extracted from section content (chars) +_TOC_NODE_SUMMARY_MAX_CHARS = 300 + +# Marker substring length for fuzzy fallback matching in _resolve_positions +_MARKER_SUBSTRING_LEN = 32 + +# Maximum span ratio: filter out overly large spans (>80% of document) +_MAX_SPAN_RATIO = 0.8 + +# Adaptive preview window for LLM structure analysis +_TREE_PREVIEW_MIN = 12_000 # Minimum preview window (chars) +_TREE_PREVIEW_MAX = 50_000 # Maximum preview window (~12K tokens) +_TREE_PREVIEW_RATIO = 0.15 # Fraction of document to preview + +# Structured content detection thresholds (Plan 1: generic table recognition) +_STRUCT_MD_TABLE_MIN_ROWS = 3 # Min markdown table rows to classify as structured +_STRUCT_NUMERIC_DENSITY_THRESHOLD = 0.20 # Fraction of numeric tokens in a text segment + +# Extensions eligible for tree indexing +_TREE_EXTENSIONS = { + ".pdf", ".docx", ".doc", ".md", ".markdown", + ".html", ".htm", ".rst", ".tex", ".txt", +} + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class TreeNode: + """Single node in the document tree.""" + + node_id: str + title: str + summary: str + char_range: Tuple[int, int] # [start, end) in the extracted text + level: int = 0 + page_range: Optional[Tuple[int, int]] = None + children: List["TreeNode"] = field(default_factory=list) + table_count: int = 0 # Number of tables associated with this node's page range + content_type: str = "text" # "text" | "table" + + def to_dict(self) -> Dict[str, Any]: + return { + "node_id": self.node_id, + "title": self.title, + "summary": self.summary, + "char_range": list(self.char_range), + "level": self.level, + "page_range": list(self.page_range) if self.page_range else None, + "children": [c.to_dict() for c in self.children], + "table_count": self.table_count, + "content_type": self.content_type, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TreeNode": + children = [cls.from_dict(c) for c in data.get("children", [])] + pr = data.get("page_range") + return cls( + node_id=data["node_id"], + title=data["title"], + summary=data["summary"], + char_range=tuple(data["char_range"]), + level=data.get("level", 0), + page_range=tuple(pr) if pr else None, + children=children, + table_count=data.get("table_count", 0), + content_type=data.get("content_type", "text"), + ) + + @property + def leaf(self) -> bool: + return len(self.children) == 0 + + def all_leaves(self) -> List["TreeNode"]: + """Return all leaf nodes under this subtree.""" + if self.leaf: + return [self] + leaves: List["TreeNode"] = [] + for c in self.children: + leaves.extend(c.all_leaves()) + return leaves + + +@dataclass +class DocumentTree: + """Complete tree index for a single document.""" + + file_path: str + file_hash: str + created_at: str + total_chars: int + total_pages: Optional[int] = None + root: Optional[TreeNode] = None + + def to_json(self) -> str: + return json.dumps({ + "file_path": self.file_path, + "file_hash": self.file_hash, + "created_at": self.created_at, + "total_chars": self.total_chars, + "total_pages": self.total_pages, + "root": self.root.to_dict() if self.root else None, + }, ensure_ascii=False, indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "DocumentTree": + data = json.loads(json_str) + root = TreeNode.from_dict(data["root"]) if data.get("root") else None + return cls( + file_path=data["file_path"], + file_hash=data["file_hash"], + created_at=data["created_at"], + total_chars=data["total_chars"], + total_pages=data.get("total_pages"), + root=root, + ) + + +# --------------------------------------------------------------------------- +# Indexer +# --------------------------------------------------------------------------- + +class DocumentTreeIndexer: + """Build and cache PageIndex-style hierarchical tree indices for documents.""" + + # Maximum child nodes before switching to paginated LLM selection. + # Balance: lower = more LLM calls, higher = more tokens per call. + _PAGE_SIZE_THRESHOLD: int = 15 + + # Number of nodes per group in paginated selection. + _GROUP_PAGE_SIZE: int = 15 + + # Minimum navigation depth before allowing early termination. + _NAV_MIN_DEPTH: int = 2 + + def __init__( + self, + llm: OpenAIChat, + cache_dir: Union[str, Path], + log_callback: LogCallback = None, + ): + self._llm = llm + self._cache_dir = Path(cache_dir) + self._cache_dir.mkdir(parents=True, exist_ok=True) + self._log = create_logger(log_callback=log_callback) + + # ------------------------------------------------------------------ # + # Public API # + # ------------------------------------------------------------------ # + + async def build_tree( + self, + file_path: str, + content: str, + *, + max_depth: int = 4, + force_rebuild: bool = False, + total_pages: Optional[int] = None, + toc_entries: Optional[List[Any]] = None, + ) -> Optional[DocumentTree]: + """Build a tree index for a document. + + When *toc_entries* are provided (from TOCExtractor), uses the + TOC-accelerated path that skips recursive LLM analysis and builds + the tree directly from extracted headings. + + Returns None when the document is too small or unstructured. + """ + file_hash = get_fast_hash(file_path) + if file_hash is None: + return None + + if not force_rebuild: + cached = self._load_cache(file_hash) + if cached is not None: + await self._log.info(f"[TreeIndexer] Cache hit for {Path(file_path).name}") + return cached + + if len(content) < _TREE_MIN_CHARS: + return None + + ext = Path(file_path).suffix.lower() + if ext not in _TREE_EXTENSIONS: + return None + + # Use adaptive depth based on document length + effective_depth = self._compute_adaptive_depth(len(content)) + + await self._log.info( + f"[TreeIndexer] Building tree for {Path(file_path).name} " + f"({len(content)} chars, depth={effective_depth})" + ) + + # TOC-accelerated path: skip recursive LLM analysis + if toc_entries: + root = await self._build_tree_from_toc( + toc_entries, content, total_pages=total_pages, + ) + if root is not None: + # NOTE: _deepen_large_leaves disabled - char_range anchoring via LLM start_text + # is unreliable, causing overlapping ranges and search failures. + # TODO: Re-enable when robust char_range calculation is implemented. + # await self._deepen_large_leaves(root, content, max_depth=effective_depth) + # Node summary enrichment: controlled by SIRCHMUNK_SKIP_NODE_SUMMARIES env var. + # Set to "true" to skip during debugging / performance testing. + _skip_summaries = os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() in ("true", "1", "yes") + print(f"SEARCH_WIKI_DEBUG [T1] enrich_node_summaries (TOC path): skip={_skip_summaries}, env={os.getenv('SIRCHMUNK_SKIP_NODE_SUMMARIES', '')}", flush=True) + if not _skip_summaries: + await self._enrich_node_summaries(root, content) + tree = DocumentTree( + file_path=file_path, + file_hash=file_hash, + created_at=datetime.now(timezone.utc).isoformat(), + total_chars=len(content), + total_pages=total_pages, + root=root, + ) + self._save_cache(file_hash, tree) + await self._log.info( + f"[TreeIndexer] Built tree from TOC: {self._count_nodes(root)} nodes" + ) + return tree + + # Fallback: existing recursive LLM path (with adaptive depth) + root = await self._build_node(content, level=0, max_depth=effective_depth) + if root is None: + return None + + # NOTE: _deepen_large_leaves disabled - char_range anchoring via LLM start_text + # is unreliable, causing overlapping ranges and search failures. + # TODO: Re-enable when robust char_range calculation is implemented. + # await self._deepen_large_leaves(root, content, max_depth=effective_depth) + # Node summary enrichment: controlled by SIRCHMUNK_SKIP_NODE_SUMMARIES env var. + # Set to "true" to skip during debugging / performance testing. + _skip_summaries = os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() in ("true", "1", "yes") + print(f"SEARCH_WIKI_DEBUG [T1] enrich_node_summaries (recursive path): skip={_skip_summaries}, env={os.getenv('SIRCHMUNK_SKIP_NODE_SUMMARIES', '')}", flush=True) + if not _skip_summaries: + await self._enrich_node_summaries(root, content) + + tree = DocumentTree( + file_path=file_path, + file_hash=file_hash, + created_at=datetime.now(timezone.utc).isoformat(), + total_chars=len(content), + total_pages=total_pages, + root=root, + ) + self._save_cache(file_hash, tree) + await self._log.info( + f"[TreeIndexer] Built tree: {self._count_nodes(root)} nodes, " + f"depth={self._max_node_depth(root)}" + ) + return tree + + async def navigate( + self, + tree: DocumentTree, + query: str, + *, + max_results: int = 3, + max_depth: int = 4, + min_depth: int = 1, + ) -> List[TreeNode]: + """Adaptive-depth LLM-driven tree navigation. + + Iteratively descends the tree using _select_children() at each level, + collecting leaf nodes until *max_results* are found or *max_depth* is + reached. Enforces *min_depth* descent before allowing early + termination to avoid overly shallow results. + + Args: + tree: DocumentTree with a root node. + query: Search query for relevance selection. + max_results: Maximum number of leaf nodes to return. + max_depth: Maximum descent depth (default 4). + min_depth: Minimum depth before early termination (default 1). + + Returns: + List of the most relevant leaf TreeNodes. + """ + if tree.root is None: + return [] + + print(f"SEARCH_WIKI_DEBUG [T2] navigate: query={query[:80]}, total_nodes={self._count_nodes(tree.root)}", flush=True) + + candidates = tree.root.children if tree.root.children else [tree.root] + if not candidates: + return [tree.root] + + # Skip single-child container chains (e.g. SEC boilerplate wrappers + # like "UNITED STATES SECURITIES AND EXCHANGE COMMISSION" → "FORM 10-K") + # to avoid wasting navigation depth on structural-only nodes. + while ( + len(candidates) == 1 + and candidates[0].children + and not candidates[0].leaf + ): + candidates = candidates[0].children + + # Adaptive min-depth: clamp to tree's actual depth + tree_max_depth = self._max_node_depth(tree.root) + effective_min_depth = min(min_depth, max(tree_max_depth - 1, 1)) + + result_leaves: List[TreeNode] = [] + visited: set = set() # prevent cycles + frontier = candidates + selected: List[TreeNode] = [] + + depth = 0 + while depth < max_depth and frontier: + selected = await self._select_children( + frontier, query, max_selections=max_results, + ) + print(f"SEARCH_WIKI_DEBUG [T3] navigate layer: depth={depth}, selected={len(selected)}, names={[n.title[:30] for n in selected][:5]}", flush=True) + + if not selected: + # Fix A.1: when depth < effective_min_depth, expand all frontier children + if depth < effective_min_depth: + next_frontier: List[TreeNode] = [] + for node in frontier: + if node.children: + next_frontier.extend(node.children) + else: + result_leaves.append(node) + if not next_frontier: + break + frontier = next_frontier + depth += 1 + continue + break + + next_frontier: List[TreeNode] = [] + for node in selected: + node_id = id(node) + if node_id in visited: + continue + visited.add(node_id) + + if node.children: + next_frontier.extend(node.children) + else: + result_leaves.append(node) + + # Fix A.3: early termination requires depth >= effective_min_depth + if len(result_leaves) >= max_results and depth >= effective_min_depth: + break + + # Fix A.4: check for empty next_frontier + if not next_frontier: + break + frontier = next_frontier + depth += 1 + + # Fallback: if no leaves found, expand last selected nodes + if not result_leaves and selected: + for node in selected: + result_leaves.extend(node.all_leaves()[:max_results]) + + # Deduplicate and cap + seen_ids: set = set() + unique: List[TreeNode] = [] + for n in result_leaves: + if n.node_id not in seen_ids: + seen_ids.add(n.node_id) + unique.append(n) + leaves = unique[:max_results] + _page_valid = sum(1 for l in leaves if getattr(l, 'page_range', None) and len(l.page_range) == 2 and l.page_range[0]) + print(f"SEARCH_WIKI_DEBUG [T4] navigate result: leaves={len(leaves)}, page_range_valid={_page_valid}", flush=True) + return leaves + + def load_tree(self, file_path: str) -> Optional[DocumentTree]: + """Load a cached tree index for the given file (sync).""" + file_hash = get_fast_hash(file_path) + if file_hash is None: + return None + return self._load_cache(file_hash) + + def has_tree(self, file_path: str) -> bool: + """Check whether a cached tree index exists for the file.""" + file_hash = get_fast_hash(file_path) + if file_hash is None: + return False + return self._cache_path(file_hash).exists() + + # ------------------------------------------------------------------ # + # Internals # + # ------------------------------------------------------------------ # + + async def _build_tree_from_toc( + self, + toc_entries: List[Any], + content: str, + *, + total_pages: Optional[int] = None, + ) -> Optional[TreeNode]: + """Build tree directly from extracted TOC entries, avoiding recursive LLM. + + Each TOCEntry becomes a TreeNode with char_range from the entry positions. + Only the root summary requires an LLM call (_synthesize_root_summary). + + Args: + toc_entries: List of TOCEntry from toc_extractor. + content: Full extracted text of the document. + total_pages: Total page count for page_range calculation. + + Returns: + Root TreeNode, or None if no children could be created. + """ + # Infer hierarchy when TOC entries are flat (all same level) + toc_entries = self._infer_hierarchy(toc_entries) + + # Merge consecutive fragment entries into virtual parents + toc_entries = self._merge_fragment_entries(toc_entries) + + # Plan 4: Group disproportionately large tail entries (exhibits/appendices) + toc_entries = self._merge_supplementary_entries(toc_entries) + + seen_ids: set = set() + children = self._toc_entries_to_nodes( + toc_entries, content, len(content), seen_ids, + fallback_level=1, total_pages=total_pages, + ) + + if not children: + return None + + root_summary = await self._synthesize_root_summary(children) + root_page_range = (1, total_pages) if total_pages and total_pages > 0 else None + return TreeNode( + node_id=self._unique_node_id(0, seen_ids), + title="Document", + summary=root_summary, + char_range=(0, len(content)), + level=0, + page_range=root_page_range, + children=children, + ) + + @staticmethod + def _merge_supplementary_entries(entries: List[Any]) -> List[Any]: + """Merge tail entries with disproportionately large spans into a virtual parent. + + Detects when the last few entries collectively span much more content + than the preceding entries — a generic structural signal for exhibits, + appendices, or attachment sections. Groups them under a single + navigable node to prevent them from dominating tree navigation. + + Uses only structural signals (char span ratios, position in document) + — no domain-specific keywords. Returns original entries when the + structural pattern is not detected or when too few entries remain. + """ + if len(entries) < 4: + return entries + + def _span(e: Any) -> int: + if hasattr(e, 'char_start') and hasattr(e, 'char_end'): + if e.char_end and e.char_start is not None: + return max(0, e.char_end - e.char_start) + return 0 + + spans = [_span(e) for e in entries] + total_span = sum(spans) + if total_span == 0: + return entries + + # Scan backwards to find tail entries whose cumulative span is + # disproportionately large while individually being much larger + # than the body-section baseline. Uses 25th percentile instead of + # median so that many large tail entries cannot inflate the baseline. + non_zero_spans = [s for s in spans if s > 0] + if len(non_zero_spans) < 4: + return entries + sorted_spans = sorted(non_zero_spans) + q25_idx = max(0, len(sorted_spans) // 4) + baseline_span = sorted_spans[q25_idx] + + tail_start = len(entries) + cumulative = 0 + for i in range(len(entries) - 1, 0, -1): + if spans[i] > baseline_span * 3: + cumulative += spans[i] + tail_start = i + else: + break + + tail_count = len(entries) - tail_start + # Require at least 2 tail entries spanning > 40% of total content + if tail_count < 2 or cumulative / total_span < 0.40: + return entries + + # Also ensure enough primary entries remain + if tail_start < 2: + return entries + + from copy import deepcopy + first_tail = entries[tail_start] + last_tail = entries[-1] + merged = deepcopy(first_tail) + merged.title = f"Supplementary Material ({tail_count} sections)" + if hasattr(last_tail, 'char_end') and last_tail.char_end: + merged.char_end = last_tail.char_end + merged.children = list(entries[tail_start:]) + + result = list(entries[:tail_start]) + [merged] + return result if len(result) >= 2 else entries + + @staticmethod + def _merge_fragment_entries(entries: List[Any]) -> List[Any]: + """Merge consecutive fragment TOC entries into virtual parent nodes. + + Detects runs of >=3 consecutive entries that have tiny char_range + spans (<500) and no children, then collapses them into a single + virtual 'Preamble' entry. Uses only structural signals (char spans, + children counts) — no domain-specific keywords. + + Safety valve: returns original *entries* if result has < 2 entries. + """ + if len(entries) <= 5: + return entries + + # Phase 1: Detect fragment runs + def _is_fragment(e: Any) -> bool: + span = 0 + if hasattr(e, 'char_start') and hasattr(e, 'char_end'): + if e.char_end and e.char_start is not None: + span = e.char_end - e.char_start + has_children = bool(getattr(e, 'children', None)) + return span < 500 and not has_children + + # Find runs of consecutive fragments + runs: List[List[int]] = [] # list of [start_idx, end_idx] inclusive + i = 0 + while i < len(entries): + if _is_fragment(entries[i]): + run_start = i + while i < len(entries) and _is_fragment(entries[i]): + i += 1 + if (i - run_start) >= 3: # Only merge runs of 3+ + runs.append([run_start, i - 1]) + else: + i += 1 + + if not runs: + return entries + + # Phase 2: Merge each run into a virtual parent + from copy import deepcopy + + result: List[Any] = [] + prev_end = -1 + for run_start, run_end in runs: + # Add non-fragment entries before this run + for j in range(prev_end + 1, run_start): + result.append(entries[j]) + + # Create virtual parent from the run + first_entry = entries[run_start] + last_entry = entries[run_end] + + merged = deepcopy(first_entry) + merged.title = f"Preamble ({run_end - run_start + 1} sections)" + if hasattr(last_entry, 'char_end') and last_entry.char_end: + merged.char_end = last_entry.char_end + # Set children to the original entries + merged.children = list(entries[run_start:run_end + 1]) + result.append(merged) + prev_end = run_end + + # Add remaining entries after last run + for j in range(prev_end + 1, len(entries)): + result.append(entries[j]) + + # Safety valve + if len(result) < 2: + return entries + + return result + + @staticmethod + def _toc_entries_to_nodes( + entries: List[Any], + content: str, + parent_end: int, + seen_ids: set, + fallback_level: int, + total_pages: Optional[int] = None, + ) -> List["TreeNode"]: + """Recursively convert TOCEntry trees into TreeNode trees. + + Handles arbitrary nesting depth and guards against invalid + char_start / char_end values. Computes ``page_range`` using a + look-ahead algorithm when ``page_start`` is available on entries. + + Args: + entries: List of TOCEntry objects (may have children). + content: Full extracted text. + parent_end: End offset inherited from the parent node. + seen_ids: Set for unique node-id generation. + fallback_level: Default level when entry.level is 0. + total_pages: Total page count for page_range look-ahead. + """ + nodes: List[TreeNode] = [] + content_len = len(content) + for i, entry in enumerate(entries): + start = max(0, min(entry.char_start, content_len)) + end = entry.char_end if entry.char_end and entry.char_end > start else parent_end + end = min(end, content_len) + + section_text = content[start:min(start + _TOC_NODE_SUMMARY_MAX_CHARS, end)] + nid = DocumentTreeIndexer._unique_node_id(start, seen_ids) + level = entry.level if entry.level > 0 else fallback_level + + # page_range: look-ahead algorithm + page_range = None + if hasattr(entry, 'page_start') and entry.page_start is not None: + # Find next sibling with page_start to determine page_end + page_end = total_pages or entry.page_start + for j in range(i + 1, len(entries)): + if hasattr(entries[j], 'page_start') and entries[j].page_start is not None: + page_end = entries[j].page_start + break + page_range = (entry.page_start, max(entry.page_start, page_end)) + + child_nodes: List[TreeNode] = [] + if entry.children: + child_nodes = DocumentTreeIndexer._toc_entries_to_nodes( + entry.children, content, end, seen_ids, + fallback_level=level + 1, + total_pages=total_pages, + ) + + # Plan 1: Detect structured/tabular content and add navigation hint + # to help LLM-driven navigation prioritize data-rich sections. + # Deliberately keeps content_type="text" so _classify_leaves + # routes to kreuzberg char_range (higher fidelity than pypdf). + summary_text = section_text.strip() + section_sample = content[start:min(start + 2000, end)] + if DocumentTreeIndexer._detect_structured_content(section_sample): + summary_text = f"[Data/Tables] {summary_text}" + + node = TreeNode( + node_id=nid, + title=entry.title, + summary=summary_text, + char_range=(start, end), + level=level, + page_range=page_range, + children=child_nodes, + ) + nodes.append(node) + return nodes + + @staticmethod + def _unique_node_id(start: int, seen_ids: set) -> str: + """Generate a unique node_id based on char offset, appending a + disambiguator when collisions occur.""" + base = f"N{start:06d}" + if base not in seen_ids: + seen_ids.add(base) + return base + suffix = 1 + while f"{base}_{suffix}" in seen_ids: + suffix += 1 + nid = f"{base}_{suffix}" + seen_ids.add(nid) + return nid + + @staticmethod + def _compute_adaptive_depth(content_length: int) -> int: + """Compute max tree depth based on document length. + + Longer documents get deeper trees for finer-grained navigation. + Uses _TREE_ADAPTIVE_DEPTH_THRESHOLDS for threshold-based selection. + + Args: + content_length: Character count of the document. + + Returns: + Maximum tree depth (2-4). + """ + for threshold, depth in _TREE_ADAPTIVE_DEPTH_THRESHOLDS: + if content_length >= threshold: + return depth + return 2 # minimum depth + + @staticmethod + def _detect_structured_content(text: str, sample_size: int = 2000) -> bool: + """Detect whether text contains structured/tabular data using generic signals. + + Uses two high-precision, domain-agnostic heuristics (any triggers True): + 1. Markdown table syntax (pipe-delimited rows with separator line) + 2. High numeric token density (currency, percentages, large numbers) + + Intentionally omits lower-precision signals (multi-space alignment, + tab counts) because PDF-extracted text frequently has irregular + spacing that causes false positives. + + Args: + text: Content segment to analyze. + sample_size: Max chars to analyze (avoids scanning huge sections). + """ + sample = text[:sample_size] + if not sample.strip(): + return False + + # Signal 1: Markdown table syntax — pipe-separated rows with header separator + pipe_lines = [ln for ln in sample.split("\n") if ln.strip().startswith("|")] + separator_lines = [ln for ln in pipe_lines if re.match(r"\|\s*[-:]+", ln)] + data_rows = len(pipe_lines) - len(separator_lines) + if data_rows >= _STRUCT_MD_TABLE_MIN_ROWS and separator_lines: + return True + + # Signal 2: Numeric token density — high ratio of numeric-pattern tokens + non_ws = re.sub(r"\s+", "", sample) + if len(non_ws) > 50: + from sirchmunk.learnings.compiler import _NUM_TOKEN_RE + num_tokens = _NUM_TOKEN_RE.findall(sample) + total_chars = sum(len(t) for t in num_tokens) + if total_chars / len(non_ws) >= _STRUCT_NUMERIC_DENSITY_THRESHOLD: + return True + + return False + + async def _build_node( + self, text: str, level: int, max_depth: int, + offset: int = 0, + ) -> Optional[TreeNode]: + """Recursively build tree nodes via LLM structure analysis.""" + from sirchmunk.llm.prompts import COMPILE_TREE_STRUCTURE + + preview_size = self._compute_preview_size(len(text)) + preview = text[:preview_size] + prompt = COMPILE_TREE_STRUCTURE.format( + document_content=preview, + max_sections=8, + ) + + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + sections = self._parse_sections(resp.content, text) + + if not sections: + return TreeNode( + node_id=f"N{offset:06d}", + title="Document", + summary=text[:300], + char_range=(offset, offset + len(text)), + level=level, + ) + + children: List[TreeNode] = [] + for i, sec in enumerate(sections): + child = TreeNode( + node_id=f"N{sec['start'] + offset:06d}", + title=sec["title"], + summary=sec["summary"], + char_range=(sec["start"] + offset, sec["end"] + offset), + level=level + 1, + ) + section_text = text[sec["start"]:sec["end"]] + if level + 1 < max_depth and len(section_text) > _TREE_MIN_CHARS: + deeper = await self._build_node( + section_text, level + 1, max_depth, offset=sec["start"] + offset, + ) + if deeper and deeper.children: + child.children = deeper.children + children.append(child) + + root_summary = await self._synthesize_root_summary(children) + + return TreeNode( + node_id=f"N{offset:06d}", + title="Document", + summary=root_summary, + char_range=(offset, offset + len(text)), + level=level, + children=children, + ) + + @staticmethod + def _collect_representative_nodes( + children: List[TreeNode], + max_nodes: int = 15, + ) -> List[TreeNode]: + """Collect representative nodes from multiple tree depths. + + Gathers direct children plus a sample of deeper descendants to + ensure the summary captures actual content topics — not just + top-level structural wrappers that may be uninformative. + + Strategy: + - Layer 1: all direct children (structural overview). + - Layer 2+: BFS preferring **leaf nodes** (actual content topics) + over intermediate nodes (whose summaries overlap children). + """ + reps: List[TreeNode] = [] + seen: set = set() + + # Layer 1: all direct children (even wrappers — they provide structure) + for c in children: + if c.node_id not in seen and len(reps) < max_nodes: + reps.append(c) + seen.add(c.node_id) + + # Layer 2+: BFS collecting leaf nodes with substantive summaries. + # Leaf nodes represent actual content sections; intermediate nodes + # often have summaries that redundantly overlap their children. + queue = [] + for c in children: + for gc in c.children: + queue.append(gc) + + while queue and len(reps) < max_nodes: + node = queue.pop(0) + if node.node_id in seen: + continue + + is_leaf = not node.children + has_substance = ( + (node.summary and len(node.summary.strip()) > 20) + or node.table_count > 0 + ) + + if is_leaf and has_substance: + reps.append(node) + seen.add(node.node_id) + elif not is_leaf: + # Expand intermediate nodes without adding them — + # their content is represented by their leaf descendants. + for ch in node.children: + queue.append(ch) + + return reps + + async def _synthesize_root_summary(self, children: List[TreeNode]) -> str: + """Synthesize a document-level summary from multi-depth section info. + + Gathers representative nodes from multiple tree depths to produce + a summary that reflects actual document content, not just top-level + wrapper headings like "SEC Filing" or "Table of Contents". + """ + if not children: + return "" + from sirchmunk.llm.prompts import COMPILE_SYNTHESIZE_SUMMARY + representatives = self._collect_representative_nodes(children) + sections_text = "\n".join( + f"- {n.title}: {n.summary}" for n in representatives + ) + prompt = COMPILE_SYNTHESIZE_SUMMARY.format(sections=sections_text) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + return resp.content.strip() + + def _parse_sections( + self, llm_output: str, full_text: str, + ) -> List[Dict[str, Any]]: + """Parse LLM section output into [{title, summary, start, end}, ...].""" + # Try JSON array first + try: + raw = llm_output + # Strip markdown fences + raw = re.sub(r"^```(?:json)?\s*", "", raw, flags=re.MULTILINE) + raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE).strip() + m = re.search(r"\[.*\]", raw, re.DOTALL) + if m: + items = json.loads(m.group()) + return self._resolve_positions(items, full_text) + except (json.JSONDecodeError, TypeError): + pass + return [] + + @staticmethod + def _resolve_positions( + items: List[Dict[str, Any]], full_text: str, + ) -> List[Dict[str, Any]]: + """Resolve section start/end character offsets from marker text. + + Two-pass algorithm: + Pass 1 — determine all start positions with tiered fallback: + exact match from prev_end -> substring match -> full-text fallback. + Pass 2 — set end[i] = start[i+1]; last end = text_len. + + Filters out invalid spans and overly large spans (> ``_MAX_SPAN_RATIO`` + of the document) to prevent accumulated positioning errors. + """ + text_lower = full_text.lower() + text_len = len(full_text) + resolved: List[Dict[str, Any]] = [] + + # Pass 1: determine all start positions + prev_end = 0 + for item in items: + title = item.get("title", "") + marker = item.get("start_marker", title) + + pos = -1 + if marker: + marker_lower = marker.lower() + # Level 1: exact match from prev_end + pos = text_lower.find(marker_lower, prev_end) + # Level 2: substring match (first N chars) from prev_end + if pos < 0 and len(marker_lower) > _MARKER_SUBSTRING_LEN: + pos = text_lower.find( + marker_lower[:_MARKER_SUBSTRING_LEN], prev_end, + ) + # Level 3: full text fallback from start + if pos < 0: + pos = text_lower.find(marker_lower, 0) + + start = pos if pos >= 0 else prev_end + resolved.append({ + "title": title, + "summary": item.get("summary", ""), + "start": start, + "end": text_len, # placeholder + }) + prev_end = ( + start + max(1, len(marker)) + if pos >= 0 + else prev_end + ) + + # Pass 2: set end[i] = start[i+1], last end = text_len + for i in range(len(resolved) - 1): + resolved[i]["end"] = resolved[i + 1]["start"] + if resolved: + resolved[-1]["end"] = text_len + + # Filter out invalid spans and overly large spans + return [ + s for s in resolved + if s["end"] > s["start"] + and (s["end"] - s["start"]) / max(text_len, 1) < _MAX_SPAN_RATIO + ] + + @staticmethod + def _filter_low_value_nodes( + nodes: List["TreeNode"], + *, + min_remaining: int = 3, + ) -> List["TreeNode"]: + """Remove only structurally empty or exact-duplicate nodes. + + Intentionally conservative: the LLM selection step receives rich + structural descriptors (page span, table count, subsection count) + and is trusted to judge relevance. This filter removes only + definitive noise that would waste LLM context: + + 1. Empty placeholders — no title, no children, zero char span, + and no summary. + 2. Exact duplicates — identical (title, page_range) pairs; among + duplicates the node with the richest structure is kept. + + Safety: returns original *nodes* when fewer than *min_remaining* + would survive. + """ + if len(nodes) <= min_remaining: + return nodes + + keep: List[bool] = [True] * len(nodes) + + def _char_span(n: "TreeNode") -> int: + cr = getattr(n, "char_range", (0, 0)) + return (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 + + # Pass 1: remove structurally empty placeholder nodes + for i, n in enumerate(nodes): + title = (n.title or "").strip() + if not title and not n.children and _char_span(n) == 0 and not n.summary: + keep[i] = False + + # Pass 2: deduplicate exact (title, page_range) pairs — + # keep the node with more structural richness. + seen: dict = {} # (title, page_range_key) → index + for i, n in enumerate(nodes): + if not keep[i]: + continue + title = (n.title or "").strip() + pr = getattr(n, "page_range", None) + pr_key = (pr[0], pr[1]) if pr and len(pr) == 2 else None + dup_key = (title, pr_key) + if dup_key in seen: + prev_i = seen[dup_key] + prev = nodes[prev_i] + richness = (len(n.children), getattr(n, "table_count", 0), _char_span(n)) + prev_richness = (len(prev.children), getattr(prev, "table_count", 0), _char_span(prev)) + if richness > prev_richness: + keep[prev_i] = False + seen[dup_key] = i + else: + keep[i] = False + else: + seen[dup_key] = i + + filtered = [n for i, n in enumerate(nodes) if keep[i]] + return filtered if len(filtered) >= min_remaining else nodes + + @staticmethod + def _build_node_descriptor(node: "TreeNode", index: int) -> str: + """Build a rich descriptor string for a single tree node. + + Includes structural signals: page span, table count, subsection + count, and depth information to help LLM make informed selections. + """ + parts = [f"[{index}] {node.title}"] + + # Page range with span + pr = getattr(node, 'page_range', None) + if pr and len(pr) == 2 and pr[0] is not None: + span_pages = pr[1] - pr[0] + 1 if pr[1] else 1 + parts.append(f"[pages {pr[0]}-{pr[1]}, {span_pages}p]") + + # Table count + if node.table_count > 0: + parts.append(f"[{node.table_count} tables]") + + # Subsections + child_count = len(node.children) + if child_count > 0: + parts.append(f"[{child_count} subsections]") + + # Summary + summary = (node.summary or "")[:200] + if summary: + parts.append(f": {summary}") + + return " ".join(parts) + + @staticmethod + def _build_selection_prompt( + nodes: List["TreeNode"], + query: str, + max_selections: int, + ) -> str: + """Build unified LLM prompt for branch selection. + + Uses structural signals to guide LLM toward high-value sections: + tables, subsection depth, page span. No domain-specific keywords. + """ + listing = "\n".join( + DocumentTreeIndexer._build_node_descriptor(n, i) + for i, n in enumerate(nodes) + ) + + sel_hint = f"1-{min(max_selections, len(nodes))}" + + return ( + f"Given the query: \"{query}\"\n\n" + f"Select the {sel_hint} most relevant sections (by index number):\n" + f"{listing}\n\n" + f"Selection criteria:\n" + f"- Prioritize sections most likely to answer the query\n" + f"- Sections with tables, data, or subsections are often high-value\n" + f"- Short sections containing relevant data should not be dismissed\n" + f"- When uncertain, prefer larger sections that can be narrowed later\n\n" + f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" + ) + + async def _select_children( + self, nodes: List[TreeNode], query: str, *, max_selections: int = 3, + ) -> List[TreeNode]: + """LLM-driven branch selection: pick the most relevant children. + + Removes only definitive noise (empty / duplicate nodes), then + dispatches to paginated selection when *nodes* exceeds + ``_PAGE_SIZE_THRESHOLD``. Relevance judgment is delegated to the LLM. + """ + if len(nodes) <= 2: + return nodes + + # Pre-filter low-value fragment nodes + nodes = self._filter_low_value_nodes(nodes) + if len(nodes) <= 2: + return nodes + + if len(nodes) > self._PAGE_SIZE_THRESHOLD: + return await self._select_children_paginated( + nodes, query, max_selections=max_selections, + ) + + prompt = self._build_selection_prompt(nodes, query, max_selections) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + try: + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + indices = json.loads(m.group()) + selected = [nodes[i] for i in indices if 0 <= i < len(nodes)] + return selected if selected else nodes[:max_selections] + except (json.JSONDecodeError, IndexError, TypeError): + pass + return nodes[:max_selections] + + async def _select_children_paginated( + self, + nodes: List[TreeNode], + query: str, + *, + page_size: int = 15, + max_selections: int = 3, + ) -> List[TreeNode]: + """Two-phase paginated selection for large node sets. + + Phase 1: partition *nodes* into sequential groups of *page_size*, + present group summaries to LLM, and select 1-2 groups. + Phase 2: run fine-grained selection within each chosen group. + + Falls back to the first *max_selections* nodes on any LLM failure. + """ + page_size = max(page_size, self._GROUP_PAGE_SIZE) + + # --- Phase 0: build groups --- + groups: List[List[TreeNode]] = [] + for start in range(0, len(nodes), page_size): + groups.append(nodes[start:start + page_size]) + + if len(groups) <= 1: + # Only one group — skip directly to fine-grained selection + return await self._select_from_group(nodes, query, max_selections) + + # --- Phase 1: group-level selection --- + group_listing = "\n".join( + f"[{i}] {g[0].title} ... {g[-1].title} ({len(g)} sections)" + for i, g in enumerate(groups) + ) + group_prompt = ( + f"Given the query: \"{query}\"\n\n" + f"The document has {len(nodes)} sections organized into " + f"{len(groups)} groups.\n" + f"Select the 1-2 most relevant groups (by index number):\n" + f"{group_listing}\n\n" + f"Return ONLY a JSON array of group index numbers, e.g. [0, 2]" + ) + + selected_groups: List[List[TreeNode]] = [] + try: + resp = await self._llm.achat( + [{"role": "user", "content": group_prompt}], + ) + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + g_indices = json.loads(m.group()) + selected_groups = [ + groups[i] for i in g_indices if 0 <= i < len(groups) + ] + except (json.JSONDecodeError, IndexError, TypeError): + pass + + if not selected_groups: + # Fallback: take the first group + selected_groups = [groups[0]] + + # --- Phase 2: fine-grained selection within chosen groups --- + results: List[TreeNode] = [] + for group in selected_groups: + picked = await self._select_from_group(group, query, max_selections) + results.extend(picked) + + # Deduplicate by node_id and cap + seen: set = set() + unique: List[TreeNode] = [] + for n in results: + if n.node_id not in seen: + seen.add(n.node_id) + unique.append(n) + return unique[:max_selections] if unique else nodes[:max_selections] + + async def _select_from_group( + self, + group: List[TreeNode], + query: str, + max_selections: int, + ) -> List[TreeNode]: + """Select the most relevant nodes within a single group via LLM.""" + if len(group) <= 2: + return group + + prompt = self._build_selection_prompt(group, query, max_selections) + try: + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + indices = json.loads(m.group()) + selected = [group[i] for i in indices if 0 <= i < len(group)] + if selected: + return selected[:max_selections] + except (json.JSONDecodeError, IndexError, TypeError): + pass + return group[:max_selections] + + # ------------------------------------------------------------------ # + # Cache I/O # + # ------------------------------------------------------------------ # + + def _cache_path(self, file_hash: str) -> Path: + return self._cache_dir / f"{file_hash}.json" + + def _save_cache(self, file_hash: str, tree: DocumentTree) -> None: + path = self._cache_path(file_hash) + path.write_text(tree.to_json(), encoding="utf-8") + print(f"SEARCH_WIKI_DEBUG [C5] tree_json_saved: path={path}", flush=True) + + def _load_cache(self, file_hash: str) -> Optional[DocumentTree]: + path = self._cache_path(file_hash) + if not path.exists(): + return None + try: + return DocumentTree.from_json(path.read_text(encoding="utf-8")) + except Exception: + return None + + # ------------------------------------------------------------------ # + # Helpers # + # ------------------------------------------------------------------ # + + @staticmethod + def _compute_preview_size(text_len: int) -> int: + """Compute adaptive preview window size for LLM structure analysis. + + Scales with document length: at least *_TREE_PREVIEW_MIN* chars, + up to *_TREE_PREVIEW_MAX*, using *_TREE_PREVIEW_RATIO* of the + document length as the baseline. + """ + return max( + _TREE_PREVIEW_MIN, + min(int(text_len * _TREE_PREVIEW_RATIO), _TREE_PREVIEW_MAX), + ) + + @staticmethod + def _count_nodes(node: TreeNode) -> int: + return 1 + sum(DocumentTreeIndexer._count_nodes(c) for c in node.children) + + @staticmethod + def _max_node_depth(node: TreeNode) -> int: + if not node.children: + return node.level + return max(DocumentTreeIndexer._max_node_depth(c) for c in node.children) + + @staticmethod + def _format_page_range( + page_range: "Optional[Tuple[int, int]]", + ) -> str: + """Format a page_range tuple into a human-readable string for prompts.""" + if not page_range: + return "" + ps, pe = page_range + return f" [pages {ps}-{pe}]" if ps != pe else f" [page {ps}]" + + # ------------------------------------------------------------------ # + # Leaf deepening & summary enrichment # + # ------------------------------------------------------------------ # + + async def _deepen_large_leaves( + self, + node: TreeNode, + content: str, + *, + max_leaf_chars: int = 5000, + max_depth: int = 4, + _seen_ids: Optional[set] = None, + ) -> None: + """Recursively deepen leaf nodes whose char_range exceeds *max_leaf_chars* using LLM decomposition.""" + if _seen_ids is None: + _seen_ids = self._collect_node_ids(node) + + if not node.leaf: + for child in node.children: + await self._deepen_large_leaves( + child, content, + max_leaf_chars=max_leaf_chars, + max_depth=max_depth, + _seen_ids=_seen_ids, + ) + return + + start, end = node.char_range + span = end - start + if span <= max_leaf_chars or node.level >= max_depth: + return + + snippet = self._truncate_snippet(content[start:end]) + + prompt = ( + "Analyze this document section and identify 3-8 logical sub-sections.\n" + "For each sub-section, provide:\n" + '- "title": descriptive heading (concise)\n' + '- "start_text": the first 8-15 words that mark where this sub-section ' + "begins (must be exact text from the content)\n" + '- "content_type": "text" or "table"\n\n' + f'Section: "{node.title}"\n---\n{snippet}\n---\n\n' + 'Return ONLY a JSON array, e.g. ' + '[{"title": "...", "start_text": "...", "content_type": "text"}, ...]' + ) + + try: + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + sub_sections = self._parse_json_array(resp.content) + if not sub_sections or len(sub_sections) < 2: + return + except Exception: + return + + sub_nodes = self._build_sub_nodes_from_llm( + sub_sections, node, content, _seen_ids, + ) + if not sub_nodes: + return + + node.children = sub_nodes + await self._log.info( + f"[TreeIndexer] Deepened '{node.title}' into {len(sub_nodes)} sub-nodes" + ) + + # Recurse into newly created children + for child in node.children: + await self._deepen_large_leaves( + child, content, + max_leaf_chars=max_leaf_chars, + max_depth=max_depth, + _seen_ids=_seen_ids, + ) + + def _build_sub_nodes_from_llm( + self, + sub_sections: List[Dict[str, Any]], + parent: TreeNode, + content: str, + seen_ids: set, + ) -> List[TreeNode]: + """Create child TreeNodes from LLM-decomposed sub-sections.""" + parent_start, parent_end = parent.char_range + parent_span = max(parent_end - parent_start, 1) + parent_ps, parent_pe = parent.page_range if parent.page_range else (0, 0) + page_span = parent_pe - parent_ps + child_level = parent.level + 1 + + # Resolve char_start for each sub-section + positions: List[int] = [] + search_from = parent_start + for sec in sub_sections: + start_text = sec.get("start_text", "") + pos = content.find(start_text, search_from) if start_text else -1 + if pos < 0 or pos >= parent_end: + pos = search_from + positions.append(pos) + search_from = pos + 1 + + nodes: List[TreeNode] = [] + for i, sec in enumerate(sub_sections): + char_start = positions[i] + char_end = positions[i + 1] if i + 1 < len(positions) else parent_end + + # Estimate page_range proportionally from parent + page_range = None + if parent.page_range and parent_span > 0: + p_start = parent_ps + (char_start - parent_start) / parent_span * page_span + p_end = parent_ps + (char_end - parent_start) / parent_span * page_span + page_range = (int(p_start), max(int(p_start), int(p_end))) + + content_type = sec.get("content_type", "text") + if content_type not in ("text", "table"): + content_type = "text" + + nodes.append(TreeNode( + node_id=self._unique_node_id(char_start, seen_ids), + title=sec.get("title", f"Sub-section {i + 1}"), + summary="", + char_range=(char_start, char_end), + level=child_level, + page_range=page_range, + content_type=content_type, + )) + return nodes + + async def _enrich_node_summaries( + self, + node: TreeNode, + content: str, + *, + max_summary_len: int = 200, + ) -> None: + """Post-order traversal to enrich empty summaries: leaf from content, non-leaf via LLM.""" + # Post-order: process children first + for child in node.children: + await self._enrich_node_summaries( + child, content, max_summary_len=max_summary_len, + ) + + if self._summary_needs_enrichment(node.summary): + if node.leaf: + node.summary = self._extract_leaf_summary( + content, node.char_range, max_summary_len, + ) + else: + node.summary = await self._generate_nonleaf_summary( + node, max_summary_len, + ) + + @staticmethod + def _summary_needs_enrichment(summary: str) -> bool: + """Check whether a summary is empty or too short to be useful.""" + return not summary or len(summary.strip()) < 10 + + @staticmethod + def _extract_leaf_summary( + content: str, + char_range: Tuple[int, int], + max_len: int, + ) -> str: + """Extract a concise summary for a leaf node from its content slice.""" + start, end = char_range + raw = content[start:end][:500] + # Clean to single line + return " ".join(raw.split())[:max_len] + + async def _generate_nonleaf_summary( + self, + node: TreeNode, + max_summary_len: int, + ) -> str: + """Generate a summary for a non-leaf node via LLM, with fallback.""" + children_listing = "\n".join( + f"- {c.title}: {c.summary[:100]}" for c in node.children + ) + prompt = ( + "Summarize this document section in 1-2 concise sentences.\n" + f'Section: "{node.title}"\n' + f"Sub-sections:\n{children_listing}\n\n" + "Return ONLY the summary text." + ) + try: + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + return resp.content.strip()[:max_summary_len] + except Exception: + # Fallback: concatenate children titles + return ", ".join(c.title for c in node.children)[:max_summary_len] + + # ------------------------------------------------------------------ # + # Parsing / snippet helpers # + # ------------------------------------------------------------------ # + + @staticmethod + def _truncate_snippet( + text: str, + *, + head_chars: int = 3000, + tail_chars: int = 1000, + ) -> str: + """Truncate a long text snippet keeping head and tail with an ellipsis marker.""" + if len(text) <= head_chars + tail_chars: + return text + return text[:head_chars] + "\n...[truncated]...\n" + text[-tail_chars:] + + @staticmethod + def _parse_json_array(raw: str) -> List[Dict[str, Any]]: + """Extract and parse a JSON array from LLM output.""" + cleaned = re.sub(r"^```(?:json)?\s*", "", raw, flags=re.MULTILINE) + cleaned = re.sub(r"```\s*$", "", cleaned, flags=re.MULTILINE).strip() + m = re.search(r"\[.*\]", cleaned, re.DOTALL) + if m: + return json.loads(m.group()) + return [] + + @staticmethod + def _collect_node_ids(node: TreeNode) -> set: + """Collect all existing node_ids in the subtree.""" + ids = {node.node_id} + for c in node.children: + ids.update(DocumentTreeIndexer._collect_node_ids(c)) + return ids + + @staticmethod + def should_build_tree(file_path: str, content_length: int) -> bool: + """Determine whether a file is eligible for tree indexing.""" + ext = Path(file_path).suffix.lower() + return ext in _TREE_EXTENSIONS and content_length >= _TREE_MIN_CHARS + + # ------------------------------------------------------------------ # + # Hierarchy inference for flat TOC entries # + # ------------------------------------------------------------------ # + + # Minimum number of TOC entries to trigger hierarchy inference. + # Documents with fewer entries are typically already well-structured. + _FLAT_ENTRY_THRESHOLD = 20 + + # If this fraction of entries share the same level, consider it "flat" + # and apply hierarchy inference. Real hierarchies typically have + # varied level distribution. + _FLAT_LEVEL_RATIO = 0.9 + + # Number of entries per virtual group when using uniform grouping fallback. + _GROUP_SIZE = 15 + + @staticmethod + def _infer_hierarchy(entries: List[Any]) -> List[Any]: + """When all entries share the same level, infer hierarchy from title patterns. + + Applies three strategies in priority order: + A. Keyword groups — detect repeated structural prefixes (generic) + B. Generic numbering patterns (1., 1.1, I., A., etc.) + C. Uniform grouping fallback (virtual parent nodes) + + Only activates when >90% of entries share the same level and + the total count exceeds ``_FLAT_ENTRY_THRESHOLD``. + + Args: + entries: List of TOCEntry (may be nested). + + Returns: + Possibly restructured list of TOCEntry with updated levels + and rebuilt hierarchy. + """ + if not entries: + return entries or [] + + try: + from sirchmunk.learnings.toc_extractor import TOCExtractor + flat: List[Any] = [] + TOCExtractor._flatten_entries(entries, flat) + except Exception: + return entries # Cannot flatten; return original entries + + if not flat: + return entries + + if len(flat) <= DocumentTreeIndexer._FLAT_ENTRY_THRESHOLD: + return entries + + # Validate level field: skip entries with invalid levels + valid_flat = [e for e in flat if hasattr(e, 'level') and isinstance(e.level, (int, float))] + if not valid_flat: + return entries + + # Check if >90% share the same level + level_counts = Counter(e.level for e in valid_flat) + dominant_level, dominant_count = level_counts.most_common(1)[0] + if dominant_count / len(flat) <= DocumentTreeIndexer._FLAT_LEVEL_RATIO: + return entries # Already has meaningful hierarchy + + # Try strategies in priority order + modified = DocumentTreeIndexer._strategy_keyword_groups(flat, dominant_level) + if modified is None: + modified = DocumentTreeIndexer._strategy_numbering(flat, dominant_level) + if modified is None: + modified = DocumentTreeIndexer._strategy_uniform_grouping( + flat, dominant_level, + ) + if modified is None: + return entries + + # Rebuild hierarchy from the re-leveled flat list + return TOCExtractor._build_hierarchy(modified) + + # -- Strategy A: keyword groups (generic structural prefix detection) # + + # Pattern: title starts with a capitalized word optionally followed by + # a Roman numeral or Arabic number (e.g. "PART IV", "Item 1A", + # "Section 3", "Chapter 12", "Article II"). + _RE_STRUCTURAL_PREFIX = re.compile( + r'^([A-Z][A-Za-z]*(?:\s+[IVXLCDM\d]+[A-Za-z]?)?)\b', + ) + + @staticmethod + def _extract_structural_prefix(title: str) -> Optional[str]: + """Extract a structural prefix from a title. + + Matches leading capitalized words optionally followed by a number + or Roman numeral (e.g. "PART IV", "Item 1A", "Section 3"). + Returns the normalized (uppercased) prefix, or None. + """ + if not title or not title.strip(): + return None + m = DocumentTreeIndexer._RE_STRUCTURAL_PREFIX.match(title.strip()) + if m: + prefix = m.group(1).strip() + # Prefix must not be too long (avoid capturing entire title) + if len(prefix) <= 20: + return prefix.upper() + return None + + @staticmethod + def _strategy_keyword_groups( + flat: List[Any], + dominant_level: int, + ) -> Optional[List[Any]]: + """Strategy A — detect repeated structural prefixes and infer levels. + + Works for any document with repetitive heading patterns (SEC filings, + legal contracts, technical specs, etc.). Automatically discovers + prefix groups and assigns hierarchical levels based on frequency: + lower-frequency prefixes become higher-level parents. + + Returns re-leveled flat list, or None if coverage is insufficient. + """ + # 1. Extract prefix for each entry + prefix_map: Dict[str, List[int]] = {} # prefix -> [entry indices] + for i, e in enumerate(flat): + prefix = DocumentTreeIndexer._extract_structural_prefix(e.title) + if prefix: + prefix_map.setdefault(prefix, []).append(i) + + # 2. Keep only prefixes appearing >= 2 times + repeated_prefixes = {k: v for k, v in prefix_map.items() if len(v) >= 2} + if not repeated_prefixes: + return None + + # 3. Check coverage: at least 30% of entries must be covered + covered = sum(len(indices) for indices in repeated_prefixes.values()) + if covered < len(flat) * 0.3: + return None + + # 4. Sort prefixes by frequency (ascending) then by first appearance + # Low frequency = higher level (parent), high frequency = lower level + sorted_prefixes = sorted( + repeated_prefixes.items(), + key=lambda x: (len(x[1]), min(x[1])), + ) + + # 5. Assign level per prefix group + prefix_to_level: Dict[str, int] = {} + for level_idx, (prefix, _) in enumerate(sorted_prefixes): + prefix_to_level[prefix] = level_idx + 1 + + # 6. Determine the "other" level for entries without a known prefix + max_level = max(prefix_to_level.values()) + 1 + + # 7. Apply levels + for i, e in enumerate(flat): + prefix = DocumentTreeIndexer._extract_structural_prefix(e.title) + if prefix and prefix in prefix_to_level: + e.level = prefix_to_level[prefix] + else: + e.level = max_level + e.children = [] + + return flat + + # -- Strategy B: generic numbering --------------------------------- # + + # Three-level numbering: 1.1.1, (a), (i), (1) + _RE_NUM_LEVEL3 = re.compile( + r"^\s*(?:\d+\.\d+\.\d+|\([a-z]\)|\([ivx]+\)|\(\d+\))\s", + re.IGNORECASE, + ) + # Two-level numbering: 1.1, A., B., a., b. + _RE_NUM_LEVEL2 = re.compile( + r"^\s*(?:\d+\.\d+(?!\.)\b|[A-Z]\.\s|[a-z]\.\s)", + ) + # Top-level numbering: 1., 2., I., II. + _RE_NUM_LEVEL1 = re.compile( + r"^\s*(?:\d+\.\s|[IVXLC]+\.\s)", + ) + + @staticmethod + def _strategy_numbering( + flat: List[Any], + dominant_level: int, + ) -> Optional[List[Any]]: + """Strategy B — detect generic numbering patterns. + + Returns re-leveled flat list, or None if fewer than 30% of + entries match any numbering pattern. + """ + matched = 0 + assignments: List[Optional[int]] = [] + + for e in flat: + title = e.title + if DocumentTreeIndexer._RE_NUM_LEVEL3.match(title): + assignments.append(3) + matched += 1 + elif DocumentTreeIndexer._RE_NUM_LEVEL2.match(title): + assignments.append(2) + matched += 1 + elif DocumentTreeIndexer._RE_NUM_LEVEL1.match(title): + assignments.append(1) + matched += 1 + else: + assignments.append(None) + + if matched < len(flat) * 0.3: + return None + + # Apply assignments; entries without a pattern get the level of + # the previous entry + 1 (capped at 3) + prev_level = 1 + for i, e in enumerate(flat): + if assignments[i] is not None: + e.level = assignments[i] + else: + e.level = min(prev_level + 1, 3) + prev_level = e.level + e.children = [] + return flat + + # -- Strategy C: uniform grouping fallback ------------------------- # + + @staticmethod + def _strategy_uniform_grouping( + flat: List[Any], + dominant_level: int, + ) -> Optional[List[Any]]: + """Strategy C — group entries into fixed-size buckets with virtual parents. + + Creates synthetic parent TOCEntry nodes whose char_start/char_end + and page_start/page_end are derived from the first and last child + in each group. + + Returns the re-leveled flat list including virtual parents, or None + on error. + """ + from sirchmunk.learnings.toc_extractor import TOCEntry + + group_size = DocumentTreeIndexer._GROUP_SIZE + num_groups = math.ceil(len(flat) / group_size) + if num_groups <= 1: + return None # Grouping would not improve anything + + parent_level = max(1, dominant_level - 1) if dominant_level > 1 else 1 + child_level = parent_level + 1 + + result: List[Any] = [] + for g in range(num_groups): + start_idx = g * group_size + end_idx = min((g + 1) * group_size, len(flat)) + group = flat[start_idx:end_idx] + + first = group[0] + last = group[-1] + + # Derive positions from children + char_start = first.char_start + char_end = last.char_end if last.char_end else None + page_start = first.page_start + page_end = last.page_start # Best available estimate + + virtual_parent = TOCEntry( + title=f"{first.title} \u2013 {last.title}", + level=parent_level, + char_start=char_start, + char_end=char_end, + page_start=page_start, + page_end=page_end, + children=[], + source="inferred", + ) + result.append(virtual_parent) + + # Set child level + for e in group: + e.level = child_level + e.children = [] + result.extend(group) + + return result diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 1a07e64..e56002c 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -99,8 +99,8 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: # Define granularity characteristics if i == 1: granularity = "Coarse-grained" - desc_text = "Multi-word phrases, compound expressions, broader concepts" - examples = '"machine learning algorithms", "data processing pipeline", "neural network training"' + desc_text = "Multi-word phrases (2-3 words) that are likely to appear **verbatim** in the target document. Prioritize standard domain terminology (e.g. financial statement headings, technical section titles)" + examples = '"capital expenditure", "net income", "accounts payable", "operating cash flow", "total revenue"' elif i == num_levels: granularity = "Fine-grained" desc_text = "Single words, precise terms, atomic concepts" @@ -189,7 +189,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Is the content meaningful and not just error messages or "no information found"? 3. Are there sufficient evidences and context to answer the user's query? -- : output "true" only if the evidence is sufficient to answer the query. +- : output "true" if the evidence contains relevant information that can help answer the query, even if it requires reasoning, computation, or interpretation. Only output "false" if the evidence is clearly irrelevant or contains no useful information for the query. - : output "true" only if the evidence is sufficient AND the result is worth caching. - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". @@ -389,6 +389,31 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: """ +FAST_QUERY_ANALYSIS_WITH_CATALOG = """Classify the user query, extract search terms, AND select the most relevant document(s) from the compiled index. + +### User Query +{user_input} + +### Compiled Document Index +{document_listing} + +### Output +Return JSON only, no extra text: +{{"type": "search", "primary": ["compound phrase"], "fallback": ["term1", "term2"], "idf": {{"compound phrase": 8.0, "term1": 2.5}}, "primary_alt": [], "fallback_alt": [], "file_hints": [], "intent": "...", "selected_docs": [0, 2], "doc_confidence": "high"}} + +Rules: +- **type**: "search" if the query requires retrieving information from files or documents; "chat" if it is a greeting, small talk, or conversational message — set primary/fallback to empty arrays, put a brief reply in "response". "summary" if the user wants to summarize entire documents. +- **primary**: 1 compound phrase (2-3 words) most likely to appear **verbatim** in the target document. +- **fallback**: 1-3 single-word atomic terms. Tried only if primary misses. +- **primary_alt / fallback_alt**: Cross-lingual equivalents (Chinese↔English). Only the most critical 1-2 terms. +- **file_hints**: filename fragments or glob patterns ONLY if clearly implied; empty array otherwise. +- **intent**: one sentence describing the query intent. +- **idf**: IDF weight (1.0-10.0) for EVERY keyword. Higher for rare terms. +- **selected_docs**: Index numbers (from the Compiled Document Index above) of the 1-3 most relevant documents for this query. Consider BOTH the filename and the summary. Choose documents whose content is most likely to answer the query. +- **doc_confidence**: "high" if you are very confident the selected documents contain the answer; "medium" if likely but uncertain; "low" if guessing. +""" + + ROI_RESULT_SUMMARY = """ ### Task Analyze the provided {text_content} and generate a concise summary in the form of a Markdown Briefing. @@ -397,6 +422,10 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 1. **Language Continuity**: The output must be in the SAME language as the User Input. 2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. +4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. +5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. +6. **Rounding**: When converting units (thousands → millions, millions → billions), round to the nearest whole number in the target unit if result ≥10; use 2 decimal places if result <10. Examples: $5,466,312 thousands → "$5,466 million"; $389 million → "$0.39 billion". Percentages: round to 1 decimal place. When the query specifies a rounding rule, follow it exactly. +7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Input Data - **User Input**: {user_input} @@ -412,14 +441,494 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Is the content meaningful and not just error messages or "no information found"? 3. Are there sufficient evidences and context to answer the user's query? -- : output "true" only if the evidence is sufficient to answer the query. +- : output "true" if the evidence contains relevant information that can help answer the query, even if it requires reasoning, computation, or interpretation. Only output "false" if the evidence is clearly irrelevant or contains no useful information for the query. - : output "true" only if the evidence is sufficient AND the result is worth caching. - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". ### Output Format -[Generate the Markdown Briefing here] +[Generate the Markdown Briefing here with detailed analysis, supporting evidence, and full calculation steps. Complete all reasoning BEFORE the PRECISE_ANSWER tag.] + + +[State ONLY the final verified answer. CRITICAL: For yes/no questions, the FIRST word MUST be "Yes" or "No". For identification questions ("What is the largest…?"), state the name/label. For value questions, state the number with units (e.g. "$1,832 million", "39.7%"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] + +true/false +true/false +""" + +ROI_RESULT_SUMMARY_WITH_CONTEXT = """ +### Task +Analyze the provided evidence and generate a concise summary in the form of a Markdown Briefing. +Leverage the document context below for better understanding of the source material's structure and purpose. + +### Constraints +1. **Language Continuity**: The output must be in the SAME language as the User Input. +2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. +3. **Style**: Keep it professional, objective, and clear. Avoid fluff. +4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. +5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. +6. **Rounding**: When converting units (thousands → millions, millions → billions), round to the nearest whole number in the target unit if result ≥10; use 2 decimal places if result <10. Examples: $5,466,312 thousands → "$5,466 million"; $389 million → "$0.39 billion". Percentages: round to 1 decimal place. When the query specifies a rounding rule, follow it exactly. +7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. + +### Document Context +{document_context} + +### Input Data +- **User Input**: {user_input} +- **Search Result Text**: {text_content} + +### Quality Evaluation +After generating the summary, make TWO decisions: +1) whether the query can be answered from the provided evidence; +2) whether this result is worth caching. + +Evaluate based on: +1. Does the search result contain substantial, relevant information for the user input? +2. Is the content meaningful and not just error messages or "no information found"? +3. Are there sufficient evidences and context to answer the user's query? + +- : output "true" if the evidence contains relevant information that can help answer the query, even if it requires reasoning, computation, or interpretation. Only output "false" if the evidence is clearly irrelevant or contains no useful information for the query. +- : output "true" only if the evidence is sufficient AND the result is worth caching. +- If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". + +### Output Format + +[Generate the Markdown Briefing here with detailed analysis, supporting evidence, and full calculation steps. Complete all reasoning BEFORE the PRECISE_ANSWER tag.] + + +[State ONLY the final verified answer. CRITICAL: For yes/no questions, the FIRST word MUST be "Yes" or "No". For identification questions ("What is the largest…?"), state the name/label. For value questions, state the number with units (e.g. "$1,832 million", "39.7%"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] + +true/false +true/false +""" + + +# --------------------------------------------------------------------------- +# Deep Structured Reasoning prompts +# --------------------------------------------------------------------------- + +DEEP_SECTION_SELECT = """Given the user query and a document section map, select the sections most likely to contain the answer. + +### User Query +{query} + +### Document Section Map +{section_map} + +### Instructions +1. Identify which sections contain data needed to answer the query. +2. For questions requiring computation (ratios, growth rates, comparisons), select ALL sections containing the required input data — even if you think some may be redundant. +3. Prefer sections containing structured data (tables, financial statements) over narrative sections. +4. For financial/annual report queries, ALWAYS include sections matching these types when available: + - Income Statement / Consolidated Statements of Operations (revenue, expenses, net income) + - Balance Sheet / Consolidated Balance Sheets (assets, liabilities, equity) + - Cash Flow Statement / Consolidated Statements of Cash Flows (capex, operating cash flow) + - Notes to Financial Statements (breakdowns, segment data, detailed schedules) + - Management's Discussion and Analysis (context, trends, explanations) +5. Select 2-6 sections. When in doubt, select MORE rather than fewer — missing data causes answer failure. + +### Output +Return ONLY a JSON array of section indices (0-based) from the map above: +[0, 3, 5] +""" + + +# --------------------------------------------------------------------------- +# DEEP mode query classification (Plan B) +# --------------------------------------------------------------------------- + +DEEP_QUERY_CLASSIFY = """Classify this search query along two dimensions. + +Query: {query} + +1. **Complexity** — how many reasoning steps are needed: + - "simple": Direct lookup of a single value (e.g. "What was revenue in FY2023?") + - "moderate": Requires light computation from 1-2 data points (e.g. "What was the gross margin?") + - "complex": Multi-step computation, multi-period comparison, or cross-entity analysis + +2. **Intent** — what the user needs: + - "lookup": Find and extract a specific stated value + - "computation": Calculate a derived metric (ratio, growth rate, difference, average) + - "comparison": Compare values across time periods, segments, or companies + +Return ONLY valid JSON on a single line: +{{"complexity": "simple", "intent": "lookup"}} +""" + +# --------------------------------------------------------------------------- +# Intent-specific synthesis prompts (Plan C) +# --------------------------------------------------------------------------- + +ROI_LOOKUP_SYNTHESIS = """### Task +Extract the specific value requested from the evidence and present it clearly. + +### Constraints +1. **Language Continuity**: The output must be in the SAME language as the User Input. +2. Find the value stated in the evidence. If the exact total is not stated but its components are clearly present, compute it by summing the components. +3. **Rounding**: When converting units (e.g., thousands → millions), round to the nearest whole number in the target unit IF the result is ≥10. If the result is <10, use 2 decimal places. Examples: $5,466,312 thousands → "$5,466 million"; $302,578 thousands → "$303 million"; $389 million → "$0.39 billion". Percentages: round to 1 decimal place. When the query specifies a rounding rule, follow it exactly. +4. If multiple candidate values exist, select based on the closest match to the query's time period, entity, and metric. +5. Quote the source passage containing the value. +6. Only mark SHOULD_ANSWER as "false" when no relevant data exists in the evidence. Always prefer attempting an answer over refusing. +7. When the evidence contains relevant data but you feel uncertain, still attempt to answer. +8. **Answer format**: + - For yes/no questions (e.g., "Has X increased?", "Did the company…?", "Does X maintain…?", "Is X healthy?"), PRECISE_ANSWER **MUST** begin with "Yes" or "No" as the very first word. Then provide a brief qualifier. + - For identification questions (e.g., "What is the largest segment?", "Which company had the highest…?"), PRECISE_ANSWER should state the name/label, not the numeric value. + - For value questions (e.g., "What was total revenue?"), PRECISE_ANSWER should state the numeric value with units. + - When asked about the "nature", "purpose", "composition", or "breakdown" of something, describe what it IS and its proportional components (e.g., "87% relates to employee liabilities"), not just the total dollar amount. + - When listing items (e.g., "Which securities are registered?"), provide the COMPLETE list from the evidence, not just one example. + +### Input Data +- **User Input**: {user_input} +- **Evidence**: {text_content} + +### Output Format + +**Source passage**: [Quote the exact text containing the answer] + +**Extracted value**: [The specific value found] +[value only, e.g. "$1,832 million", "Yes, it increased by 5%", "Cloud Services segment"] true/false true/false """ + +ROI_COMPUTATION_SYNTHESIS = """### Task +Answer the query by extracting data from the evidence and performing the required calculation. + +### Constraints +1. **Language Continuity**: The output must be in the SAME language as the User Input. +2. Follow this STRICT sequence — do NOT skip any step: + a) **DATA EXTRACTION**: List each required data point with its exact value and where you found it. + b) **FORMULA**: State the formula needed (e.g. Gross Margin = (Revenue - COGS) / Revenue). + c) **SUBSTITUTION**: Plug in the extracted values into the formula. + d) **CALCULATION**: Show arithmetic step by step. For each step, write the operation and its result. + e) **VERIFICATION**: Re-compute the final result independently to confirm. +3. **Rounding**: + - Dollar amounts: when converting units, round to the nearest whole number in the target unit IF the result is ≥10. If the result is <10 in the target unit, use 2 decimal places. Examples: $381,603 thousands → "$382 million"; $5,466,312 thousands → "$5,466 million"; $389 million → "$0.39 billion". + - Percentages: round to 1 decimal place. + - Ratios: round to 2 decimal places. + - Per-share values: round to 2 decimal places. + - When the query specifies "round to X decimal places", follow that exactly. +4. **Units**: Convert all values to consistent units before computing. +5. If any required data point is missing, explicitly state what is missing and mark SHOULD_ANSWER as "false". +6. **Financial ratio definitions**: + - **Quick ratio** = (Cash and Cash Equivalents + Short-term Investments + Net Receivables) / Total Current Liabilities. Do NOT include inventories, prepaid expenses, or other current assets in the numerator. + - **Interest coverage ratio** = EBIT / Interest Expense. If EBIT is negative, the coverage ratio is zero (or negative) — a company cannot service debt from negative earnings. + - **Asset turnover** = Revenue / Average Total Assets. + - A quick ratio below 1.0x generally indicates the company does NOT have a reasonably healthy liquidity position. +7. **Answer format**: + - For yes/no questions (e.g., "Does X have healthy liquidity?", "Has X improved?", "Does X maintain…?"), PRECISE_ANSWER **MUST** begin with "Yes" or "No" as the very first word. + - For identification questions, state the name/label, not just the number. + - When asked about "nature", "purpose", or "composition", describe qualitative aspects and proportions, not just total amounts. + +### Input Data +- **User Input**: {user_input} +- **Evidence**: {text_content} + +### Output Format + +## Data Extraction +| Data Point | Value | Source | +|---|---|---| +| [name] | [exact value] | [where found in evidence] | + +## Calculation +**Formula**: [state formula] +**Step 1**: [operation] = [result] +**Step 2**: [operation] = [result] +**Verification**: [re-compute to confirm] + +[final computed value only] +true/false +true/false +""" + +ROI_COMPARISON_SYNTHESIS = """### Task +Compare the requested values across the specified dimensions (time periods, entities, or segments). + +### Constraints +1. **Language Continuity**: The output must be in the SAME language as the User Input. +2. Extract values for EACH comparison dimension from the evidence. +3. Present in a structured comparison table. +4. State the direction and magnitude of difference or change. +5. **Rounding**: When computing changes or growth rates, round percentages to 1 decimal place. When converting units (e.g., thousands → millions), round to nearest whole number in target unit if result ≥10; otherwise use 2 decimal places. +6. **"Best performing"** means highest growth rate or change rate, not highest absolute value, unless the query explicitly says "largest" or "highest revenue". +7. If values for any comparison dimension are missing, state what is missing. +8. **Answer format**: For yes/no questions ("Has X improved?", "Was there any change?"), PRECISE_ANSWER **MUST** begin with "Yes" or "No" as the very first word, followed by the comparison details. + +### Input Data +- **User Input**: {user_input} +- **Evidence**: {text_content} + +### Output Format + +## Comparison +| Dimension | Value | Source | +|---|---|---| +| [period/entity] | [value] | [where found] | + +## Analysis +**Direction**: [increased/decreased/stable] +**Magnitude**: [absolute and/or percentage change, with arithmetic shown] + +[concise comparison result, e.g. "Increased from $1.2B to $1.5B (25% growth)"] +true/false +true/false +""" + +# --------------------------------------------------------------------------- +# Evidence completeness check (Plan D) +# --------------------------------------------------------------------------- + +EVIDENCE_COMPLETENESS_CHECK = """Given the query and available evidence, determine whether all data points needed to answer are present. + +### Query +{query} + +### Query Type +{intent} + +### Evidence (excerpt) +{evidence_excerpt} + +### Instructions +1. Identify the specific data points required to answer this query. +2. Check whether each data point's actual value appears in the evidence. +3. A data point is FOUND only if its numeric/factual value is explicitly stated. + +Return ONLY valid JSON on a single line: +{{"complete": true, "missing": []}} +or +{{"complete": false, "missing": ["short description of what is missing"]}} +""" + +# --------------------------------------------------------------------------- +# Computation correction (Plan E) +# --------------------------------------------------------------------------- + +COMPUTATION_CORRECTION = """Your previous calculation contained an arithmetic error. Please revise. + +### Query +{query} + +### Your Previous Answer +{original_answer} + +### Detected Error +- Expression: {expression} +- Your result: {llm_result} +- Correct result: {correct_result} + +Revise your answer using the correct arithmetic. Keep the same analysis structure. + + +[Corrected analysis with fixed calculation] + +[Corrected final value] +true +true +""" + +# --------------------------------------------------------------------------- +# Agentic retrieval prompts (DEEP mode) +# --------------------------------------------------------------------------- + +DEEP_DATA_REQUIREMENTS = """Given the user's question, identify the specific data points needed to answer it. + +### Question +{query} + +### Question Type +{intent} + +### Instructions +1. List each specific data point needed to answer this question (e.g., "Total Revenue for FY2022", "Accounts Payable as of fiscal year end 2019"). +2. For each data point, identify the likely document section type where it would appear (e.g., "Income Statement", "Balance Sheet", "Cash Flow Statement", "Notes to Financial Statements", "Management Discussion and Analysis", "Segment Information"). +3. If a calculation is required, state the exact formula with explicit variable names matching how they typically appear in financial statements. If the question provides its own formula definition, use THAT formula exactly. Otherwise use these standard definitions: + - Quick Ratio = (Cash and Cash Equivalents + Short-term Investments + Net Receivables) / Total Current Liabilities + - Interest Coverage Ratio = EBIT / Interest Expense (if EBIT is negative, ratio = 0) + - Asset Turnover = Revenue / Average Total Assets + - Net Profit Margin = Net Income / Total Revenue +4. Identify the time period(s) required. +5. For comparison or identification questions (e.g., "What is the largest segment?", "Which year had the highest growth?"), note what dimensions need comparison. + +Return ONLY valid JSON on a single line: +{{"data_points": ["data point 1", "data point 2"], "likely_sources": ["section type 1", "section type 2"], "formula": "explicit formula with variable names, or null", "time_period": "period or null"}} +""" + +DEEP_PAGE_SELECT = """You are locating specific data in a document. Select pages to fetch. + +### Question +{query} + +### Data Still Needed +{data_requirements} + +### Document Outline (with page ranges) +{section_map} + +### Pages Already Fetched +{fetched_pages} + +### Evidence Already Gathered +{evidence_summary} + +### Instructions +- Reason about which sections contain the needed data based on section titles, summaries, and page ranges. +- Consider what data has already been gathered to avoid fetching redundant content. +- Financial statements (Income Statement, Balance Sheet, Cash Flow Statement) typically contain quantitative data needed for calculations. +- Sections with tables are often high-value for data extraction. +- Do NOT re-select pages listed in "Pages Already Fetched". +- Select 3-8 pages that are most likely to contain the missing data. +- When uncertain, prefer sections deeper in the document (financial statements are usually after narrative sections). + +Return ONLY a JSON array of page numbers to fetch: [45, 46, 52, 53] +""" + +DEEP_CHECK_REQUIREMENTS = """Check whether the evidence contains all required data points. + +### Question +{query} + +### Required Data Points +{data_points} + +### Formula (if applicable) +{formula} + +### Evidence +{evidence} + +### Instructions +For each required data point, check if its actual numeric or factual value appears in the evidence. A data point is FOUND only if you can identify its specific value in the text. + +Return ONLY valid JSON: +{{"complete": true, "found": [{{"point": "description", "value": "extracted value"}}], "missing": []}} +or +{{"complete": false, "found": [{{"point": "description", "value": "extracted value"}}], "missing": ["description of missing data point"]}} +""" + +DEEP_TOC_ANALYSIS = """Analyze the following pages from the beginning of a document and extract its structural outline. + +### Document Pages +{toc_page_text} + +### Total Document Pages +{total_pages} + +### Instructions +1. Look for a table of contents, section listing, or structural overview. +2. Extract every section entry with its title, starting page number, and hierarchy level. +3. Infer page_end from the start of the next section (use {total_pages} for the last section). +4. If page numbers appear as dot leaders (e.g. "Item 7. MD&A ........ 45"), extract the page number. +5. If no structural information can be extracted, return an empty array. + +Return ONLY valid JSON — an array of section objects: +[{{"title": "Section Title", "page_start": 3, "page_end": 15, "level": 1}}, ...] + +If no structure found, return: [] +""" + +# --------------------------------------------------------------------------- +# Knowledge Compile prompts +# --------------------------------------------------------------------------- + +COMPILE_TREE_STRUCTURE = """Analyze the following document and identify its natural hierarchical structure (chapters, sections, subsections). + +### Document Content (may be truncated) +{document_content} + +### Output Requirements +Return a JSON array of top-level sections. Each section object must have: +- "title": Section heading or descriptive title +- "summary": 1-2 sentence summary of the section content +- "start_marker": A short text string (5-15 words) that appears verbatim at the start of this section in the document +- "end_marker": A short text string that appears at the start of the NEXT section (empty for the last section) + +Maximum {max_sections} sections. Identify only the most significant structural boundaries. + +### Output Format +Return ONLY a JSON array, no extra text: +[ + {{"title": "...", "summary": "...", "start_marker": "...", "end_marker": "..."}}, + ... +] +""" + + +COMPILE_SYNTHESIZE_SUMMARY = """Synthesize a comprehensive document summary from the following section summaries. + +### Section Summaries +{sections} + +### Output +Provide a unified, coherent summary in 3-8 sentences that captures the document's overall topic, key arguments, and conclusions. Do not simply list the sections — weave them into a natural narrative. +Write in the same language as the section summaries.""" + + +COMPILE_DOC_SUMMARY = """Summarize the following document concisely, capturing the key topics, arguments, conclusions, and important details. + +### File: {file_name} + +### Document Content (may be truncated) +{document_content} + +### Output +Provide a comprehensive summary in 3-8 sentences. Focus on: +1. What is this document about (main topic/purpose) +2. Key findings, arguments, or conclusions +3. Important details, data points, or methodologies + +Write the summary in the same language as the document content.""" + + +COMPILE_TOPIC_EXTRACTION = """Extract the 3-5 most important topics, concepts, or entities from the following summary. + +### Summary +{summary} + +### Output +Return ONLY a JSON array of topic strings, no extra text: +["topic1", "topic2", "topic3"] + +Rules: +- Each topic should be 1-4 words +- Prefer specific, domain-relevant terms over generic ones +- Use the same language as the summary""" + + +COMPILE_CLASSIFY_HEADINGS = """Classify each bold text line as either a **section heading** or **non-heading**. + +A line is a *section heading* if it serves as the title of a major structural division of the document (chapter, section, subsection, exhibit, schedule, financial statement, note, etc.). +A line is *non-heading* if it is emphasis text, a label, a caption, a total/subtotal row, or any inline bold phrase that does not introduce a new document section. + +For each heading, also assign a Markdown heading level (2–4): +- Level 2: top-level sections (e.g. financial statements, major chapters) +- Level 3: sub-sections (e.g. notes to financial statements, sub-chapters) +- Level 4: sub-sub-sections + +Return ONLY a JSON array of objects for the lines that ARE headings. +Each object: {{"idx": <0-based index>, "level": <2|3|4>}} +If none are headings, return an empty array: [] + +Bold lines: +{candidates}""" + + +COMPILE_MERGE_KNOWLEDGE = """You are merging new information into an existing knowledge cluster. + +### Existing Knowledge +{existing_content} + +### New Information +{new_summary} + +### Task +Produce an updated, unified summary that: +1. Preserves all important information from the existing knowledge +2. Integrates the new information, avoiding redundancy +3. Highlights any contradictions or complementary perspectives +4. Maintains a coherent, well-structured narrative + +### Output +Return ONLY the merged summary text (no extra tags or metadata). Keep the same language as the inputs.""" diff --git a/src/sirchmunk/schema/knowledge.py b/src/sirchmunk/schema/knowledge.py index 336963d..2a6e149 100644 --- a/src/sirchmunk/schema/knowledge.py +++ b/src/sirchmunk/schema/knowledge.py @@ -57,6 +57,12 @@ class EvidenceUnit: # IDs of conflict group if this evidence contradicts others conflict_group: Optional[List[str]] = None + # Tree-index node path from root to the matched node (e.g. ["N000000", "N001234"]) + tree_path: Optional[List[str]] = None + + # Character range within the document for precise evidence location + page_range: Optional[List[int]] = None + def to_dict(self) -> Dict[str, Any]: """ Serialize EvidenceUnit to a dictionary. @@ -69,6 +75,8 @@ def to_dict(self) -> Dict[str, Any]: "snippets": self.snippets, "extracted_at": self.extracted_at.isoformat(), "conflict_group": self.conflict_group, + "tree_path": self.tree_path, + "page_range": self.page_range, } @@ -234,6 +242,9 @@ class KnowledgeCluster: # Used for semantic similarity matching and cluster reuse queries: List[str] = None + # Number of times this cluster has been merged with new evidence during compile + merge_count: int = 0 + def __post_init__(self): if self.related_clusters is None: self.related_clusters = [] @@ -391,5 +402,6 @@ def to_dict(self) -> Dict[str, Any]: "related_clusters": [rc.to_dict() for rc in self.related_clusters], "search_results": self.search_results, "queries": self.queries, + "merge_count": self.merge_count, } diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index aee1c16..d02c138 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -8,22 +8,32 @@ import os import re import traceback +from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union from sirchmunk.base import BaseSearch from sirchmunk.learnings.knowledge_base import KnowledgeBase +from sirchmunk.utils.document_extractor import DocumentExtractor from sirchmunk.llm.openai_chat import OpenAIChat from sirchmunk.llm.prompts import ( KEYWORD_QUERY_PLACEHOLDER, generate_keyword_extraction_prompt, FAST_QUERY_ANALYSIS, + FAST_QUERY_ANALYSIS_WITH_CATALOG, ROI_RESULT_SUMMARY, - SEARCH_RESULT_SUMMARY, + ROI_LOOKUP_SYNTHESIS, + ROI_COMPUTATION_SYNTHESIS, + ROI_COMPARISON_SYNTHESIS, DOC_SUMMARY, DOC_CHUNK_SUMMARY, DOC_MERGE_SUMMARIES, + DEEP_SECTION_SELECT, + DEEP_DATA_REQUIREMENTS, + DEEP_PAGE_SELECT, + DEEP_CHECK_REQUIREMENTS, + DEEP_TOC_ANALYSIS, ) from sirchmunk.retrieve.text_retriever import GrepRetriever from sirchmunk.schema.knowledge import ( @@ -81,6 +91,174 @@ _NO_RESULTS_MESSAGE = "No results found." +# Soft-similarity threshold for gradient cluster reuse (P2) +_SOFT_SIM_THRESHOLD = 0.65 + + +class _PathScope: + """Immutable search-path scope for filtering compile artifacts. + + Resolves the provided search paths into absolute file paths and + directory prefixes, then offers ``contains()`` to test whether a + given artifact path falls within this scope. + + When the scope is empty (no paths provided), ``contains()`` always + returns True — i.e. *no filtering* is applied. + """ + + __slots__ = ("_files", "_dirs", "_empty") + + def __init__(self, search_paths: Optional[List[str]] = None) -> None: + files: Set[str] = set() + dirs: List[str] = [] + if search_paths: + for p in search_paths: + resolved = str(Path(p).expanduser().resolve()) + if Path(resolved).is_file(): + files.add(resolved) + elif Path(resolved).is_dir(): + dirs.append( + resolved if resolved.endswith(os.sep) + else resolved + os.sep + ) + else: + files.add(resolved) + self._files = frozenset(files) + self._dirs = tuple(dirs) + self._empty = not files and not dirs + + def contains(self, file_path: str) -> bool: + """Return True when *file_path* falls within the search scope.""" + if self._empty: + return True + if not file_path: + return False + resolved = str(Path(file_path).expanduser().resolve()) + if resolved in self._files: + return True + return any(resolved.startswith(d) for d in self._dirs) + + @property + def is_empty(self) -> bool: + return self._empty + +# Pure tree search mode for ablation experiments. +# When enabled, search relies solely on tree index navigation, skipping rga keyword search. +_PURE_TREE_SEARCH: bool = os.getenv("SIRCHMUNK_PURE_TREE_SEARCH", "false").lower() == "true" + +# Common English stop-words filtered out during keyword coverage computation. +_STOP_WORDS: frozenset = frozenset({ + "the", "is", "a", "an", "of", "in", "for", "to", "and", "or", + "what", "how", "which", "does", "was", "were", "has", "have", "had", + "do", "did", "are", "be", "been", "by", "with", "from", "this", + "that", "it", "its", "on", "at", "as", "not", "no", +}) + + +@dataclass +class SoftClusterHit: + """Signals from clusters that are related but below the hard reuse threshold. + + Carries structured hints (keywords, file paths, background context) that + downstream retrieval phases can exploit without short-circuiting the search. + """ + + patterns: List[str] + file_paths: List[str] + context_summary: str + cluster_ids: List[str] + + +@dataclass +class KnowledgeProbeResult: + """Rich result from knowledge cache probing (P3). + + Replaces the flat ``List[str]`` that ``_probe_knowledge_cache`` used to return. + """ + + file_paths: List[str] + extra_keywords: List[str] + background_context: str + + +@dataclass +class CompileHints: + """Zero-LLM hints gathered from compile manifest and tree cache (P4).""" + + file_paths: List[str] + extra_keywords: List[str] + + +@dataclass +class CompileArtifacts: + """Compile artifact availability context for adaptive activation in FAST mode. + + Created once at the start of ``_search_fast()`` via + ``_detect_compile_artifacts()`` and threaded through all pipeline steps. + Each step checks the relevant field and falls back gracefully when the + artifact is absent. + """ + + catalog: Optional[List[Dict[str, str]]] + catalog_map: Dict[str, Dict[str, str]] # path -> catalog entry for O(1) lookup + tree_indexer: Optional[Any] # DocumentTreeIndexer (lazy import) + tree_available_paths: Set[str] # file paths that have cached tree indices + manifest_map: Dict[str, Any] = field(default_factory=dict) # {path: FileManifestEntry} + summary_index: Optional[Any] = None # CompileSummaryIndex (lazy-loaded) + + +@dataclass +class DataRequirements: + """Pre-retrieval analysis of what data points a query needs.""" + + data_points: List[str] + likely_sources: List[str] + formula: Optional[str] + time_period: Optional[str] + intent: str + + +@dataclass +class RetrievalResult: + """Output of the agentic retrieval loop.""" + + evidence: str + pages_extracted: Dict[str, List[int]] + is_complete: bool + rounds_used: int + + +class _TreeNavCache: + """Per-search-session cache for tree navigation results. + + Avoids duplicate LLM navigation calls for the same file+query pair. + Created at the start of each ``_search_fast()`` invocation and reset + per search session. + """ + + __slots__ = ("_store",) + + def __init__(self) -> None: + self._store: Dict[str, Optional[List[Any]]] = {} + + @staticmethod + def _key(file_path: str, query: str) -> str: + import hashlib + return hashlib.md5(f"{file_path}:{query}".encode()).hexdigest() + + def get(self, file_path: str, query: str) -> Optional[List[Any]]: + """Retrieve cached navigation leaves for a file+query pair.""" + key = self._key(file_path, query) + return self._store.get(key) + + def has(self, file_path: str, query: str) -> bool: + """Check whether a cached result exists.""" + return self._key(file_path, query) in self._store + + def put(self, file_path: str, query: str, leaves: Optional[List[Any]]) -> None: + """Store navigation leaves for a file+query pair.""" + self._store[self._key(file_path, query)] = leaves + class AgenticSearch(BaseSearch): @@ -419,6 +597,14 @@ async def _try_reuse_cluster(self, query: str, paths: Optional[List[str]] = None ) return None + # P3: skip clusters whose cached answer is a refusal + if self._is_refusal_answer(content): + await self._logger.info( + f"Cluster {existing_cluster.id} contains a refusal answer, " + "falling back to full search" + ) + return None + # Mutate only after validation passes self._add_query_to_cluster(existing_cluster, query) existing_cluster.hotness = min(1.0, (existing_cluster.hotness or 0.5) + 0.1) @@ -460,6 +646,72 @@ async def _try_reuse_cluster(self, query: str, paths: Optional[List[str]] = None ) return None + async def _try_soft_reuse( + self, query: str, paths: Optional[List[str]] = None, + ) -> Optional[SoftClusterHit]: + """Gradient reuse: extract structured hints from moderately similar clusters. + + Called when ``_try_reuse_cluster`` misses (similarity < hard threshold). + Uses a softer threshold to find clusters that are *related* but not + close enough for full reuse. Returns patterns, file paths, and a + background context summary that downstream phases can exploit. + """ + if not self.embedding_client or not self.embedding_client.is_ready(): + return None + + try: + query_embedding = (await self.embedding_client.embed([query]))[0] + similar = await self.knowledge_storage.search_similar_clusters( + query_embedding=query_embedding, + top_k=5, + similarity_threshold=_SOFT_SIM_THRESHOLD, + search_paths=paths, + ) + if not similar: + return None + + patterns: List[str] = [] + file_paths: List[str] = [] + context_parts: List[str] = [] + cluster_ids: List[str] = [] + seen_paths: set = set() + + for match in similar: + cid = match["id"] + cluster_ids.append(cid) + c = await self.knowledge_storage.get(cid) + if not c: + continue + for p in getattr(c, "patterns", []) or []: + if p and p not in patterns: + patterns.append(p) + for ev in getattr(c, "evidences", []): + fp = str(getattr(ev, "file_or_url", "")) + if fp and fp not in seen_paths and Path(fp).exists(): + seen_paths.add(fp) + file_paths.append(fp) + content = c.content + if isinstance(content, list): + content = "\n".join(content) + if content: + context_parts.append(str(content)[:500]) + + if not patterns and not file_paths: + return None + + await self._logger.info( + f"[SoftReuse] {len(similar)} soft hits: " + f"{len(patterns)} patterns, {len(file_paths)} files" + ) + return SoftClusterHit( + patterns=patterns[:10], + file_paths=file_paths[:10], + context_summary="\n\n".join(context_parts[:3]), + cluster_ids=cluster_ids, + ) + except Exception: + return None + def _add_query_to_cluster(self, cluster: KnowledgeCluster, query: str) -> None: """ Add query to cluster's queries list with FIFO strategy. @@ -478,6 +730,36 @@ def _add_query_to_cluster(self, cluster: KnowledgeCluster, query: str) -> None: # Remove oldest queries (from the beginning) cluster.queries = cluster.queries[-self.max_queries_per_cluster:] + @staticmethod + def _enrich_reused_content(cluster: KnowledgeCluster) -> str: + """Build the answer text from a reused cluster. + + When the cluster carries compiled evidence with non-empty snippets + (populated during ``sirchmunk compile``), appends them as supporting + excerpts so the user sees both the summary and the underlying source + material. + """ + content = cluster.content + if isinstance(content, list): + content = "\n".join(content) + content = str(content or "") + + evidence_parts: List[str] = [] + for ev in getattr(cluster, "evidences", []): + snippets = getattr(ev, "snippets", None) + if not snippets: + continue + source = str(getattr(ev, "file_or_url", "unknown")) + for snip in snippets: + text = snip if isinstance(snip, str) else snip.get("snippet", "") + if text and text.strip(): + evidence_parts.append(f"[{Path(source).name}] {text.strip()}") + + if evidence_parts: + content += "\n\n---\nSupporting evidence:\n" + "\n\n".join(evidence_parts[:5]) + + return content + async def _save_cluster_with_embedding(self, cluster: KnowledgeCluster) -> None: """Save knowledge cluster to persistent storage, compute embedding, and flush to parquet. @@ -734,22 +1016,52 @@ async def _search_by_filename( await self._logger.error(f"Traceback: {traceback.format_exc()}") return [] - @staticmethod - def _parse_summary_response(llm_response: str) -> Tuple[str, bool, bool]: - """ - Parse LLM response to extract summary and quality decisions. + _SELF_CORRECTION_PATTERN = re.compile( + r'(?:correction|re-?verif|wait,?\s|let me re|actually|self-correction|recalcul)', + re.IGNORECASE, + ) - Args: - llm_response: Raw LLM response containing SUMMARY, SHOULD_ANSWER and SHOULD_SAVE tags + _REFUSAL_PATTERN = re.compile( + r'cannot\s+(?:be\s+)?determin' + r'|data\s+(?:not\s+available|insufficient)' + r'|not\s+(?:possible|available)\s+to\s+(?:determin|calculat|answer)' + r'|information\s+(?:is\s+)?not\s+(?:available|provided|found)' + r'|no\s+(?:relevant|sufficient)\s+(?:data|information|evidence)', + re.IGNORECASE, + ) + + @classmethod + def _is_refusal_answer(cls, text: str) -> bool: + """Detect whether *text* is a refusal / no-data answer.""" + if not text or len(text.strip()) < 20: + return True + head = text[:500] + if re.search(r'\bN/?A\b', head): + return True + return bool(cls._REFUSAL_PATTERN.search(head)) + + @classmethod + def _parse_summary_response(cls, llm_response: str) -> Tuple[str, bool, bool]: + """Parse LLM response to extract summary, precise answer, and quality decisions. + + When a ```` tag is present, its content is prepended to + the summary so downstream consumers (evaluation judges, UIs) see the + direct answer prominently without needing separate tag awareness. + + The method also detects self-correction patterns in the summary text: + when the LLM revised its calculation mid-stream, the last numeric + conclusion is used if PRECISE_ANSWER is absent or matches the + pre-correction value. Returns: Tuple of (summary_text, should_save_flag, should_answer_flag) """ summary_fields = extract_fields( content=llm_response, - tags=["SUMMARY", "SHOULD_ANSWER", "SHOULD_SAVE"], + tags=["PRECISE_ANSWER", "SUMMARY", "SHOULD_ANSWER", "SHOULD_SAVE"], ) + precise = str(summary_fields.get("precise_answer") or "").strip() summary = str(summary_fields.get("summary") or "").strip() should_answer_str = str(summary_fields.get("should_answer") or "false").strip().lower() should_save_str = str(summary_fields.get("should_save") or "false").strip().lower() @@ -757,2004 +1069,6676 @@ def _parse_summary_response(llm_response: str) -> Tuple[str, bool, bool]: should_answer = should_answer_str in ["true", "yes", "1"] should_save = should_save_str in ["true", "yes", "1"] - # If extraction failed, use entire response as summary and default to conservative: - # not answerable and not saveable. + if precise and summary: + summary = f"**Answer: {precise}**\n\n{summary}" + elif precise: + summary = precise + if not summary: summary = llm_response.strip() - should_answer = False + # Fallback: detect **Answer: xxx** markdown format used by models + # that ignore / tags (e.g. qwen). + _answer_match = re.search( + r'\*\*Answer:\s*(.+?)\*\*', llm_response, re.DOTALL, + ) + if _answer_match: + _answer_val = _answer_match.group(1).strip() + if _answer_val and not cls._is_refusal_answer(_answer_val): + should_answer = True + should_save = True + if not precise: + precise = _answer_val + else: + should_answer = False + should_save = False + else: + should_answer = False + should_save = False + + # P3: Never persist refusal/no-data answers to cluster cache + if should_save and cls._is_refusal_answer(precise or summary): should_save = False return summary, should_save, should_answer + # ------------------------------------------------------------------ + # Multi-factor evidence acceptance helpers + # ------------------------------------------------------------------ + @staticmethod - def _extract_and_validate_multi_level_keywords( - llm_resp: str, - num_levels: int = 3 - ) -> List[Dict[str, float]]: - """ - Extract and validate multiple sets of keywords from LLM response. + def _compute_keyword_coverage(query: str, evidence: str) -> float: + """Compute the fraction of query keywords found in the evidence text. - Args: - llm_resp: LLM response containing keyword sets - num_levels: Number of keyword granularity levels to extract + Tokenises *query* into lowercase alpha-numeric words (length >= 2), + removes common English stop-words, then checks presence in + lower-cased *evidence*. Returns: - List of keyword dicts, one for each level: [level1_keywords, level2_keywords, ...] + Coverage ratio in [0.0, 1.0]. Returns 0.0 when no valid + keywords can be extracted from *query*. """ - keyword_sets: List[Dict[str, float]] = [] + tokens = re.findall(r'\b[a-z0-9]{2,}\b', query.lower()) + keywords = [t for t in tokens if t not in _STOP_WORDS] + if not keywords: + return 0.0 + evidence_lower = evidence.lower() + matched = sum(1 for kw in keywords if kw in evidence_lower) + return matched / len(keywords) - # Generate tags dynamically based on num_levels - tags = [f"KEYWORDS_LEVEL_{i + 1}" for i in range(num_levels)] + @staticmethod + def _detect_numeric_evidence(query: str, evidence: str) -> bool: + """Detect whether *evidence* contains structured numeric data relevant to *query*. - # Extract all fields at once - extracted_fields = extract_fields(content=llm_resp, tags=tags) + Returns True when *query* implies a numeric/financial intent AND + *evidence* contains numeric patterns (currency amounts, percentages, + financial figures). + """ + query_lower = query.lower() + has_intent = any( + kw in query_lower + for kw in AgenticSearch._NUMERIC_INTENT_KEYWORDS + ) + if not has_intent: + return False + has_numeric = bool( + re.search( + r'[\$\u20ac\u00a3]\s?\d' + r'|(? str: + """Classify *query* as ``simple``, ``moderate``, or ``complex``. - if not keywords_json: - keyword_sets.append({}) - continue + Used by DEEP mode to decide whether to invoke the heavier + section-map structured reasoning pipeline or go straight to + cluster-level synthesis. + """ + if any(p.search(query) for p in cls._COMPLEX_QUERY_PATTERNS): + return "complex" + if any(p.search(query) for p in cls._MODERATE_QUERY_PATTERNS): + return "moderate" + return "simple" - # Try to parse as dict format - try: - keywords_dict = json.loads(keywords_json) - except json.JSONDecodeError: - try: - keywords_dict = ast.literal_eval(keywords_json) - except Exception: - keyword_sets.append({}) - continue + _VALID_COMPLEXITIES = frozenset({"simple", "moderate", "complex"}) + _VALID_INTENTS = frozenset({"lookup", "computation", "comparison"}) - # Validate using Pydantic model - try: - validated = KeywordValidation(root=keywords_dict).model_dump() - keyword_sets.append(validated) - except Exception: - keyword_sets.append({}) + async def _classify_query_intent( + self, query: str, + ) -> Tuple[str, str]: + """Classify query complexity and intent via LLM. - return keyword_sets + Falls back to regex-based ``_classify_query_complexity`` when the + LLM call fails or returns unparseable output. - @staticmethod - def _extract_alt_keywords(llm_resp: str) -> Dict[str, float]: - """Extract cross-lingual keywords from ```` block.""" - fields = extract_fields(content=llm_resp, tags=["KEYWORDS_ALT"]) - raw = fields.get("keywords_alt") - if not raw: - return {} + Returns: + ``(complexity, intent)`` where complexity is + ``simple|moderate|complex`` and intent is + ``lookup|computation|comparison``. + """ try: - parsed = json.loads(raw) - if isinstance(parsed, dict): - return {k: float(v) for k, v in parsed.items() if isinstance(k, str)} - except (json.JSONDecodeError, TypeError, ValueError): + from sirchmunk.llm.prompts import DEEP_QUERY_CLASSIFY + + resp = await self.llm.achat( + messages=[{ + "role": "user", + "content": DEEP_QUERY_CLASSIFY.format(query=query), + }], + stream=True, + ) + self.llm_usages.append(resp.usage) + + raw = (resp.content or "").strip() + data = self._extract_json_object(raw) + if data: + complexity = data.get("complexity", "").lower() + intent = data.get("intent", "").lower() + if (complexity in self._VALID_COMPLEXITIES + and intent in self._VALID_INTENTS): + return complexity, intent + except Exception as exc: + await self._logger.warning( + f"[QueryClassify] LLM classification failed: {exc}, " + f"falling back to regex" + ) + + complexity = self._classify_query_complexity(query) + intent = "computation" if complexity != "simple" else "lookup" + return complexity, intent + + @staticmethod + def _extract_json_object(raw: str) -> Optional[dict]: + """Extract the outermost JSON object from LLM response text.""" + start = raw.find("{") + end = raw.rfind("}") + if start >= 0 and end > start: try: - parsed = ast.literal_eval(raw) - if isinstance(parsed, dict): - return {k: float(v) for k, v in parsed.items() if isinstance(k, str)} - except Exception: + return json.loads(raw[start : end + 1]) + except (json.JSONDecodeError, TypeError): pass - return {} - - # ------------------------------------------------------------------ - # Agentic (ReAct) infrastructure — lazy initialisation - # ------------------------------------------------------------------ + return None - def _ensure_tool_registry( - self, - paths: List[str], - enable_dir_scan: bool = False, - max_depth: Optional[int] = 5, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - ) -> "ToolRegistry": - """Build (or rebuild) the tool registry for the given search paths. + @staticmethod + def _extract_json_array(raw: str) -> Optional[list]: + """Extract the outermost JSON array from LLM response text.""" + start = raw.find("[") + end = raw.rfind("]") + if start >= 0 and end > start: + try: + return json.loads(raw[start : end + 1]) + except (json.JSONDecodeError, TypeError): + pass + return None - The registry is cached on ``self._tool_registry`` and re-created - only when ``paths`` change (detected via sorted hash). + @staticmethod + def _evaluate_evidence_acceptance( + query: str, + evidence: str, + llm_should_answer: bool, + *, + retrieval_complete: bool = False, + ) -> Tuple[bool, str]: + """Multi-factor decision on whether to accept retrieved evidence. - Args: - paths: Normalised list of path strings. - enable_dir_scan: Whether to include the directory-scan tool. - max_depth: Maximum directory depth for keyword search. - include: File patterns to include (glob). - exclude: File patterns to exclude (glob). + Combines the LLM's own SHOULD_ANSWER judgment with heuristic + signals (evidence length, keyword coverage, numeric-data presence) + to reduce false-negative rejections of valid evidence. Returns: - Ready-to-use ToolRegistry. + A tuple of (*accept*, *reason*) where *accept* is the final + boolean decision and *reason* is a human-readable string + documenting which factor(s) determined the outcome. """ - from sirchmunk.agentic.tools import ( - FileReadTool, - KeywordSearchTool, - KnowledgeQueryTool, - ToolRegistry, + # Factor 0: Agentic retrieval confirmed data completeness + if retrieval_complete: + return True, "retrieval_complete" + + # Factor 1: LLM direct acceptance + if llm_should_answer: + return True, "llm_accepted" + + # Factor 2: Heuristic override — length + keyword coverage + evidence_len = len(evidence) if evidence else 0 + kw_coverage = ( + AgenticSearch._compute_keyword_coverage(query, evidence) + if evidence else 0.0 ) - # Cache key: paths + filter params (all affect tool behaviour) - cache_key = ( - tuple(sorted(paths)), - max_depth, - tuple(include) if include else None, - tuple(exclude) if exclude else None, - ) if ( - self._tool_registry is not None - and getattr(self, "_tool_registry_key", None) == cache_key + evidence_len >= AgenticSearch._EVIDENCE_MIN_ACCEPT_LENGTH + and kw_coverage >= AgenticSearch._EVIDENCE_KEYWORD_COVERAGE_THRESHOLD ): - return self._tool_registry - - registry = ToolRegistry() - - # Tool 1: Knowledge cache (zero cost) - registry.register(KnowledgeQueryTool(self.knowledge_storage)) + return True, ( + f"heuristic_override(len={evidence_len}, " + f"kw_coverage={kw_coverage:.2f})" + ) - # Tool 2: Keyword search (low cost) - registry.register( - KeywordSearchTool( - retriever=self.grep_retriever, - paths=paths, - max_depth=max_depth if max_depth is not None else 5, - max_results=10, - include=include, - exclude=exclude, + # Factor 3: Numeric evidence detection + if AgenticSearch._detect_numeric_evidence(query, evidence or ""): + return True, ( + f"numeric_evidence(len={evidence_len}, " + f"kw_coverage={kw_coverage:.2f})" ) - ) - # Tool 3: File read (medium cost) - registry.register(FileReadTool(max_chars_per_file=30000)) + # All factors negative + return False, ( + f"rejected(llm=false, len={evidence_len}, " + f"kw_coverage={kw_coverage:.2f}, numeric=false)" + ) - # Tool 4: Directory scan (optional, medium cost) - if enable_dir_scan: - from sirchmunk.agentic.dir_scan_tool import DirScanTool - from sirchmunk.scan.dir_scanner import DirectoryScanner + # ------------------------------------------------------------------ + # Plan E: Computation verification + # ------------------------------------------------------------------ - if self._dir_scanner is None: - self._dir_scanner = DirectoryScanner(llm=self.llm, max_files=500) - registry.register(DirScanTool( - scanner=self._dir_scanner, - paths=paths, - )) + # ------------------------------------------------------------------ + # Plan D: Evidence adequacy closed-loop + # ------------------------------------------------------------------ - self._tool_registry = registry - self._tool_registry_key = cache_key - return registry - - # ------------------------------------------------------------------ - # Unified search entry point - # ------------------------------------------------------------------ - - async def search( + async def _check_evidence_completeness( self, query: str, - paths: Optional[Union[str, Path, List[str], List[Path]]] = None, - *, - mode: Literal["DEEP", "FAST", "FILENAME_ONLY"] = "FAST", - max_loops: int = 10, - max_token_budget: int = 128000, - max_depth: Optional[int] = 5, - top_k_files: int = 5, - enable_dir_scan: bool = False, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - return_context: bool = False, - spec_stale_hours: float = 72.0, - chat_history: Optional[List[Dict[str, str]]] = None, - llm_fallback: bool = False, - ) -> Union[str, SearchContext, List[Dict[str, Any]]]: - """Perform intelligent search with multi-mode support. + intent: str, + evidence: str, + ) -> Tuple[bool, List[str]]: + """Check if evidence contains all data points needed for the query. - Modes: - +--------------+-------------------+-------------------------------------------+ - | Mode | Speed / LLM Calls | Description | - +--------------+-------------------+-------------------------------------------+ - | FILENAME_ONLY| Very Fast / 0 | Pattern-based file discovery, no LLM. | - | FAST | 1-5s / 0-2 | Greedy: cluster reuse or keyword search | - | | | → best file → answer. Early termination. | - | DEEP | 5-30s / 4-6 | Parallel multi-path retrieval + ReAct | - | | | refinement with Monte-Carlo evidence. | - +--------------+-------------------+-------------------------------------------+ + Returns: + ``(is_complete, missing)`` where *missing* lists descriptions + of data points not found in the evidence. + """ + try: + from sirchmunk.llm.prompts import EVIDENCE_COMPLETENESS_CHECK - FAST architecture (greedy early-termination): + prompt = EVIDENCE_COMPLETENESS_CHECK.format( + query=query, + intent=intent, + evidence_excerpt=evidence[:3000], + ) + resp = await self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) - ┌──────────────────────────────────────────────────────────┐ - │ Step 0 Cluster reuse check (instant short-circuit) │ - ├──────────────────────────────────────────────────────────┤ - │ Step 1 LLM query analysis → keywords + file hints │ - │ (single call, stream=False) │ - ├──────────────────────────────────────────────────────────┤ - │ Step 2 rga keyword search → ranked file hits + snippets │ - │ (no LLM, greedy: take first good results) │ - ├──────────────────────────────────────────────────────────┤ - │ Step 3 Read top file(s) content │ - │ (no LLM, early termination at top_k_files) │ - ├──────────────────────────────────────────────────────────┤ - │ Step 4 LLM answer synthesis from evidence │ - └──────────────────────────────────────────────────────────┘ + raw = (resp.content or "").strip() + match = re.search(r'\{[^}]+\}', raw, re.DOTALL) + if match: + data = json.loads(match.group()) + is_complete = bool(data.get("complete", True)) + missing = data.get("missing", []) + if isinstance(missing, list) and missing: + return False, [str(m) for m in missing[:5]] + return is_complete, [] + except Exception as exc: + await self._logger.warning( + f"[Phase 3.75] Completeness check failed: {exc}" + ) + return True, [] - DEEP architecture (phases execute as parallel as possible): + async def _fill_evidence_gaps( + self, + query: str, + missing: List[str], + file_paths: List[str], + artifacts: Any, + *, + scope: Optional["_PathScope"] = None, + nav_cache: Optional[Dict[str, str]] = None, + ) -> Optional[str]: + """Targeted evidence retrieval for identified gaps. - ┌──────────────────────────────────────────────────────────┐ - │ Phase 0a Direct document analysis (intent-gated, │ - │ short-circuit if query is doc-level operation) │ - ├──────────────────────────────────────────────────────────┤ - │ Phase 0 Cluster reuse check (instant, short-circuit) │ - ├──────────────────────────────────────────────────────────┤ - │ Phase 1 Parallel probing (all concurrent): │ - │ ├─ LLM keyword extraction │ - │ ├─ DirectoryScanner.scan() (filesystem only, fast) │ - │ ├─ Knowledge cache similarity search │ - │ └─ Spec-path cache load │ - ├──────────────────────────────────────────────────────────┤ - │ Phase 2 Parallel retrieval (depends on Phase 1): │ - │ ├─ keyword_search per extracted keyword (concurrent rga)│ - │ └─ DirectoryScanner.rank() (LLM ranks candidates) │ - ├──────────────────────────────────────────────────────────┤ - │ Phase 3 Merge + evidence assembly: │ - │ └─ knowledge_base.build() (parallel per-file Monte │ - │ Carlo evidence sampling) │ - ├──────────────────────────────────────────────────────────┤ - │ Phase 4 Summary / ReAct refinement: │ - │ └─ If evidence sufficient → LLM summary │ - │ Else → ReAct loop for adaptive follow-up │ - ├──────────────────────────────────────────────────────────┤ - │ Phase 5 Persistence (concurrent, awaited): │ - │ ├─ Save cluster + embeddings │ - │ └─ Save spec-path cache │ - └──────────────────────────────────────────────────────────┘ + Constructs focused sub-queries from *missing* descriptions and + re-navigates tree indices or falls back to keyword retrieval. - Args: - query: User's search query. - paths: Directories / files to search. Falls back to - ``self.paths`` or the current working directory. - mode: Search mode — ``"DEEP"``, ``"FAST"``, or ``"FILENAME_ONLY"``. - max_loops: Maximum ReAct iterations (DEEP mode, default: 10). - max_token_budget: LLM token budget (DEEP mode, default: 128000). - max_depth: Maximum directory depth for file search (default: 5). - Used in both FILENAME_ONLY and DEEP modes. - top_k_files: Max files for evidence extraction (default: 5). - enable_dir_scan: Enable directory scanning (FAST and DEEP modes). - include: File glob patterns to include (e.g. ``["*.py", "*.md"]``). - Used in both FILENAME_ONLY and DEEP modes. - exclude: File glob patterns to exclude (e.g. ``["*.log"]``). - Used in both FILENAME_ONLY and DEEP modes. - return_context: If True, return a ``SearchContext`` object - that carries ``answer``, ``cluster`` (KnowledgeCluster), - and full pipeline telemetry (LLM usage, files read, etc.). - spec_stale_hours: Hours before spec cache is stale (default: 72). - chat_history: Optional list of chat messages for context (DEEP mode). - llm_fallback: When True, if no relevant documents are found, - the LLM will attempt to answer the query from its own - knowledge. Default False. + When *scope* is provided, extra files drawn from + ``artifacts.tree_available_paths`` are filtered to the scope. + When *nav_cache* is provided, navigation results are cached to + avoid duplicate LLM calls across phases. - Returns: - - ``str``: Answer summary (default). - - ``SearchContext``: If *return_context* — contains ``answer``, - ``cluster``, and telemetry in a single object. - - ``List[Dict]``: File matches in FILENAME_ONLY mode. + Returns supplementary evidence text, or None. """ - paths = self.validate_search_paths( - self._resolve_paths(paths), - ) - if not paths: - msg = "No valid search paths remain after validation." - _loguru_logger.warning(msg) - if return_context: - ctx = SearchContext() - ctx.answer = msg - return ctx - return msg + sub_query = f"{query} — specifically: {'; '.join(missing)}" + parts: List[str] = [] - # ---- Chat intent short-circuit (rule-based, no LLM cost) ---- - if mode != "FILENAME_ONLY" and self._is_chat_query(query): - answer, cluster, ctx = await self._respond_chat(query, chat_history=chat_history) - if return_context: - ctx.answer = answer - return ctx - return answer + async def _navigate(fp: str, q: str) -> Optional[str]: + if nav_cache is not None: + return await self._cached_navigate_tree(fp, q, nav_cache) + return await self._navigate_tree_for_evidence(fp, q) - # ---- FILENAME_ONLY: pattern-based file discovery, no LLM ---- - if mode == "FILENAME_ONLY": - results = await self._search_by_filename( - query=query, paths=paths, max_depth=max_depth, - include=include, exclude=exclude, top_k=top_k_files, - ) - if not results: - msg = f"No files found matching query: '{query}'" - await self._logger.warning(msg) - return msg - await self._logger.success(f"Retrieved {len(results)} matching files") - return results + indexer = self._get_tree_indexer() + for fp in file_paths[:3]: + try: + if indexer and indexer.has_tree(fp): + ev = await _navigate(fp, sub_query) + if ev and len(ev.strip()) > 100: + parts.append( + f"[Gap-fill: {Path(fp).name}]\n{ev}" + ) + continue + ev = await self._tree_guided_sample(fp, sub_query) + if isinstance(ev, str) and len(ev.strip()) > 100: + parts.append(f"[Gap-fill: {Path(fp).name}]\n{ev}") + except Exception: + continue - # ---- FAST / DEEP → both produce (answer, cluster, context) ---- - if mode == "FAST": - answer, cluster, context = await self._search_fast( - query=query, paths=paths, max_depth=max_depth, - top_k_files=top_k_files, enable_dir_scan=enable_dir_scan, - include=include, exclude=exclude, - llm_fallback=llm_fallback, - ) - else: - answer, cluster, context = await self._search_deep( - query=query, paths=paths, - max_loops=max_loops, max_token_budget=max_token_budget, - max_depth=max_depth, top_k_files=top_k_files, - enable_dir_scan=enable_dir_scan, - include=include, exclude=exclude, - spec_stale_hours=spec_stale_hours, - llm_fallback=llm_fallback, - ) + if not parts and artifacts and artifacts.tree_available_paths: + extra_fps = [ + fp for fp in artifacts.tree_available_paths + if fp not in file_paths + and (not scope or scope.contains(fp)) + ][:2] + for fp in extra_fps: + try: + ev = await _navigate(fp, sub_query) + if ev and len(ev.strip()) > 100: + parts.append( + f"[Gap-fill extra: {Path(fp).name}]\n{ev}" + ) + except Exception: + continue - # ---- Unified return wrapping ---- - if return_context: - prefix = "FS" if mode == "FAST" else "DS" - context.answer = answer - if (answer or "").strip().lower() == _NO_RESULTS_MESSAGE.lower(): - context.cluster = cluster - return context - # Use read_file_ids from context if available, otherwise empty - fallback_files = list(context.read_file_ids) if context.read_file_ids else None - context.cluster = cluster or self._make_answer_cluster( - query, answer, prefix, file_paths=fallback_files, - ) - return context - return answer + if not parts: + return None + return "\n\n".join(parts) # ------------------------------------------------------------------ - # DEEP mode — parallel multi-path retrieval with ReAct fallback + # Plan E: Computation verification # ------------------------------------------------------------------ - async def _search_deep( + _ARITH_PATTERNS = [ + re.compile( + r'[\$€£]?\s*' + r'([\d,]+(?:\.\d+)?)\s*' + r'([+\-\*/])\s*' + r'[\$€£]?\s*' + r'([\d,]+(?:\.\d+)?)\s*' + r'=\s*' + r'[\$€£]?\s*' + r'([\-]?[\d,]+(?:\.\d+)?)\s*%?' + ), + re.compile( + r'\(\s*' + r'[\$€£]?\s*([\d,]+(?:\.\d+)?)\s*' + r'([+\-])\s*' + r'[\$€£]?\s*([\d,]+(?:\.\d+)?)\s*' + r'\)\s*[/\*]\s*' + r'[\$€£]?\s*([\d,]+(?:\.\d+)?)\s*' + r'=\s*' + r'[\$€£]?\s*([\-]?[\d,]+(?:\.\d+)?)\s*%?' + ), + ] + + _SAFE_EVAL_NS: Dict[str, Any] = {"__builtins__": {}, "abs": abs, "round": round} + _ARITH_TOLERANCE: float = 0.01 + + @classmethod + def _extract_arithmetic_expressions(cls, text: str) -> List[Dict[str, Any]]: + """Extract arithmetic expressions and their stated results from text. + + Returns list of ``{"expr": str, "stated": float, "computed": float}``. + Only includes entries where Python evaluation succeeded. + """ + results: List[Dict[str, Any]] = [] + + def _parse_num(s: str) -> float: + return float(s.replace(",", "")) + + for line in text.split("\n"): + for pat in cls._ARITH_PATTERNS: + for m in pat.finditer(line): + groups = m.groups() + try: + if len(groups) == 4: + a, op, b, stated = groups + a_val, b_val = _parse_num(a), _parse_num(b) + expr = f"{a_val} {op} {b_val}" + computed = eval(expr, cls._SAFE_EVAL_NS) + results.append({ + "expr": expr, + "stated": _parse_num(stated), + "computed": float(computed), + "raw": m.group(), + }) + elif len(groups) == 5: + a, op, b, divisor, stated = groups + a_val, b_val = _parse_num(a), _parse_num(b) + d_val = _parse_num(divisor) + inner = f"{a_val} {op} {b_val}" + inner_result = eval(inner, cls._SAFE_EVAL_NS) + op2 = "/" if "/" in line[m.start():m.end()] else "*" + computed = eval( + f"{inner_result} {op2} {d_val}", + cls._SAFE_EVAL_NS, + ) + results.append({ + "expr": f"({inner}) {op2} {d_val}", + "stated": _parse_num(stated), + "computed": float(computed), + "raw": m.group(), + }) + except Exception: + continue + return results + + async def _verify_computation( self, query: str, - paths: List[str], - *, - max_loops: int = 10, - max_token_budget: int = 128000, - max_depth: Optional[int] = 5, - top_k_files: int = 5, - enable_dir_scan: bool = False, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - spec_stale_hours: float = 72.0, - llm_fallback: bool = False, - ) -> Tuple[str, Optional[KnowledgeCluster], SearchContext]: - """Parallel multi-path retrieval pipeline (Phases 0a–5). + answer: str, + ) -> Tuple[str, bool]: + """Verify arithmetic in computation-type answers. + + Extracts arithmetic expressions, evaluates them with Python, and + re-prompts the LLM if a discrepancy is detected. Returns: - ``(answer, cluster, context)`` tuple. + ``(corrected_answer, was_corrected)``. """ - context = SearchContext( - max_token_budget=max_token_budget, - max_loops=max_loops, - ) - _llm_usage_start = len(self.llm_usages) + expressions = self._extract_arithmetic_expressions(answer) + if not expressions: + return answer, False + + discrepancies = [] + for expr_info in expressions: + stated = expr_info["stated"] + computed = expr_info["computed"] + if stated == 0 and computed == 0: + continue + denom = max(abs(stated), abs(computed), 1e-9) + if abs(stated - computed) / denom > self._ARITH_TOLERANCE: + discrepancies.append(expr_info) - # ============================================================== - # Phase 0a: Direct document analysis (intent-gated short-circuit) - # ============================================================== - direct = await self._try_direct_doc_analysis(query, paths) - if direct is not None: - return direct, self._make_answer_cluster(query, direct, "DQ", file_paths=paths), context + if not discrepancies: + return answer, False - # ============================================================== - # Phase 0: Cluster reuse (instant short-circuit) - # When reuse_knowledge=True and a similar cluster is found, we - # return here — Phase 5 (Persistence) is not executed for that path. - # ============================================================== - reused = await self._try_reuse_cluster(query, paths) - if reused is not None: - content = reused.content - if isinstance(content, list): - content = "\n".join(content) - return str(content), reused, context + worst = max( + discrepancies, + key=lambda d: abs(d["stated"] - d["computed"]), + ) - await self._logger.info(f"[search] Starting multi-path retrieval for: '{query[:80]}'") + await self._logger.info( + f"[Phase 4.5:Verify] Arithmetic discrepancy: " + f"{worst['expr']} = {worst['stated']} (stated) vs " + f"{worst['computed']} (computed)" + ) - # ============================================================== - # Phase 1: Parallel probing — all four paths fire concurrently - # ============================================================== - await self._logger.info("[Phase 1] Parallel probing: keywords + dir_scan + knowledge + spec_cache") - context.increment_loop() + try: + from sirchmunk.llm.prompts import COMPUTATION_CORRECTION - phase1_results = await asyncio.gather( - self._probe_keywords(query), - self._probe_dir_scan(paths, enable_dir_scan), - self._probe_knowledge_cache(query), - self._load_spec_context(paths, stale_hours=spec_stale_hours), - return_exceptions=True, - ) + correction_prompt = COMPUTATION_CORRECTION.format( + query=query, + original_answer=answer[:3000], + expression=worst["expr"], + llm_result=worst["stated"], + correct_result=worst["computed"], + ) + resp = await self.llm.achat( + messages=[{"role": "user", "content": correction_prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) - kw_result = phase1_results[0] if not isinstance(phase1_results[0], Exception) else ({}, []) - scan_result = phase1_results[1] if not isinstance(phase1_results[1], Exception) else None - knowledge_hits = phase1_results[2] if not isinstance(phase1_results[2], Exception) else [] - spec_context = phase1_results[3] if not isinstance(phase1_results[3], Exception) else "" + corrected = resp.content or "" + if corrected and len(corrected) > 100: + await self._logger.info( + "[Phase 4.5:Verify] Correction applied" + ) + return corrected, True + except Exception as exc: + await self._logger.warning( + f"[Phase 4.5:Verify] Correction failed: {exc}" + ) - for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache"]): - if isinstance(phase1_results[i], Exception): - await self._logger.warning(f"[Phase 1] {label} probe failed: {phase1_results[i]}") + return answer, False - query_keywords, initial_keywords = kw_result if isinstance(kw_result, tuple) else ({}, []) + @staticmethod + def _extract_and_validate_multi_level_keywords( + llm_resp: str, + num_levels: int = 3 + ) -> List[Dict[str, float]]: + """ + Extract and validate multiple sets of keywords from LLM response. - await self._logger.info( - f"[Phase 1] Results: keywords={len(initial_keywords)}, " - f"dir_scan={'OK' if scan_result else 'N/A'}, " - f"knowledge_hits={len(knowledge_hits)}, " - f"spec_cache={'YES' if spec_context else 'NO'}" - ) + Args: + llm_resp: LLM response containing keyword sets + num_levels: Number of keyword granularity levels to extract - # ============================================================== - # Phase 2: Parallel retrieval — keyword search + dir_scan rank - # ============================================================== - await self._logger.info("[Phase 2] Parallel retrieval: rga keyword search + dir_scan LLM rank") - context.increment_loop() + Returns: + List of keyword dicts, one for each level: [level1_keywords, level2_keywords, ...] + """ + keyword_sets: List[Dict[str, float]] = [] - phase2_tasks = [] + # Generate tags dynamically based on num_levels + tags = [f"KEYWORDS_LEVEL_{i + 1}" for i in range(num_levels)] - if initial_keywords: - phase2_tasks.append( - self._retrieve_by_keywords( - initial_keywords, paths, - max_depth=max_depth, include=include, exclude=exclude, - ) - ) - else: - phase2_tasks.append(self._async_noop([])) + # Extract all fields at once + extracted_fields = extract_fields(content=llm_resp, tags=tags) - if scan_result is not None and enable_dir_scan: - phase2_tasks.append( - self._rank_dir_scan_candidates(query, scan_result) - ) - else: - phase2_tasks.append(self._async_noop([])) + for level_idx, tag in enumerate(tags, start=1): + keywords_dict: Dict[str, float] = {} + keywords_json: Optional[str] = extracted_fields.get(tag.lower(), None) + + if not keywords_json: + keyword_sets.append({}) + continue - phase2_results = await asyncio.gather(*phase2_tasks, return_exceptions=True) + # Try to parse as dict format + try: + keywords_dict = json.loads(keywords_json) + except json.JSONDecodeError: + try: + keywords_dict = ast.literal_eval(keywords_json) + except Exception: + keyword_sets.append({}) + continue - keyword_files = phase2_results[0] if not isinstance(phase2_results[0], Exception) else [] - dir_scan_files = phase2_results[1] if not isinstance(phase2_results[1], Exception) else [] + # Validate using Pydantic model + try: + validated = KeywordValidation(root=keywords_dict).model_dump() + keyword_sets.append(validated) + except Exception: + keyword_sets.append({}) - for i, label in enumerate(["keyword_search", "dir_scan_rank"]): - if isinstance(phase2_results[i], Exception): - await self._logger.warning(f"[Phase 2] {label} failed: {phase2_results[i]}") + return keyword_sets - await self._logger.info( - f"[Phase 2] Results: keyword_files={len(keyword_files)}, " - f"dir_scan_files={len(dir_scan_files)}" + @staticmethod + def _extract_alt_keywords(llm_resp: str) -> Dict[str, float]: + """Extract cross-lingual keywords from ```` block.""" + fields = extract_fields(content=llm_resp, tags=["KEYWORDS_ALT"]) + raw = fields.get("keywords_alt") + if not raw: + return {} + try: + parsed = json.loads(raw) + if isinstance(parsed, dict): + return {k: float(v) for k, v in parsed.items() if isinstance(k, str)} + except (json.JSONDecodeError, TypeError, ValueError): + try: + parsed = ast.literal_eval(raw) + if isinstance(parsed, dict): + return {k: float(v) for k, v in parsed.items() if isinstance(k, str)} + except Exception: + pass + return {} + + # ------------------------------------------------------------------ + # Agentic (ReAct) infrastructure — lazy initialisation + # ------------------------------------------------------------------ + + def _ensure_tool_registry( + self, + paths: List[str], + enable_dir_scan: bool = False, + max_depth: Optional[int] = 5, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ) -> "ToolRegistry": + """Build (or rebuild) the tool registry for the given search paths. + + The registry is cached on ``self._tool_registry`` and re-created + only when ``paths`` change (detected via sorted hash). + + Args: + paths: Normalised list of path strings. + enable_dir_scan: Whether to include the directory-scan tool. + max_depth: Maximum directory depth for keyword search. + include: File patterns to include (glob). + exclude: File patterns to exclude (glob). + + Returns: + Ready-to-use ToolRegistry. + """ + from sirchmunk.agentic.tools import ( + FileReadTool, + KeywordSearchTool, + KnowledgeQueryTool, + TreeNavigationTool, + ToolRegistry, ) - # ============================================================== - # Phase 3: Merge file paths + build KnowledgeCluster - # ============================================================== - context.increment_loop() - merged_files = self._merge_file_paths( - keyword_files=keyword_files, - dir_scan_files=dir_scan_files, - knowledge_hits=knowledge_hits, + # Cache key: paths + filter params (all affect tool behaviour) + cache_key = ( + tuple(sorted(paths)), + max_depth, + tuple(include) if include else None, + tuple(exclude) if exclude else None, ) - await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") + if ( + self._tool_registry is not None + and getattr(self, "_tool_registry_key", None) == cache_key + ): + return self._tool_registry - cluster: Optional[KnowledgeCluster] = None - if merged_files: - cluster = await self._build_cluster( - query=query, file_paths=merged_files, - query_keywords=query_keywords, top_k_files=top_k_files, - ) + registry = ToolRegistry() - # ============================================================== - # Phase 4: Generate answer — cluster summary or ReAct refinement - # ============================================================== - context.increment_loop() - answer: str = "" - should_save: bool = True + # Tool 1: Knowledge cache (zero cost) + registry.register(KnowledgeQueryTool(self.knowledge_storage)) - if cluster and cluster.content: - await self._logger.info("[Phase 4] Evidence sufficient, generating summary") - answer, should_save, should_answer = await self._summarise_cluster(query, cluster) - if not should_answer: - if llm_fallback: - await self._logger.info( - "[Phase 4] Summary gate rejected evidence, llm_fallback=True → LLM fallback" - ) - answer, should_save = await self._summarise_cluster_fallback(query) - else: - await self._logger.warning( - "[Phase 4] Summary gate rejected evidence and llm_fallback=False " - "→ returning no results" - ) - return _NO_RESULTS_MESSAGE, None, context - if not cluster.search_results: - cluster.search_results = list(merged_files) - elif llm_fallback: - await self._logger.info( - "[Phase 4] Evidence insufficient, llm_fallback=True \u2192 LLM summary" - ) - answer, should_save = await self._summarise_cluster_fallback(query) - else: - await self._logger.info("[Phase 4] Evidence insufficient, launching ReAct refinement") - react_answer, context = await self._react_refinement( - query=query, paths=paths, - initial_keywords=initial_keywords, spec_context=spec_context, - enable_dir_scan=enable_dir_scan, - max_loops=max_loops, max_token_budget=max_token_budget, - max_depth=max_depth, include=include, exclude=exclude, + # Tool 2: Keyword search (low cost) + registry.register( + KeywordSearchTool( + retriever=self.grep_retriever, + paths=paths, + max_depth=max_depth if max_depth is not None else 5, + max_results=10, + include=include, + exclude=exclude, ) + ) - if not cluster: - cluster = await self._build_cluster_from_context( - query=query, answer=react_answer, context=context, - query_keywords=query_keywords, top_k_files=top_k_files, - ) - elif react_answer and not cluster.content: - cluster.content = react_answer - - if not cluster: - await self._logger.warning( - "[Phase 4] ReAct found no buildable evidence and llm_fallback=False " - "→ returning no results" - ) - return _NO_RESULTS_MESSAGE, None, context + # Tool 3: File read (medium cost) + registry.register(FileReadTool(max_chars_per_file=30000)) - # Final DEEP decision is always made in the summary call. - answer, should_save, should_answer = await self._summarise_cluster(query, cluster) - if not should_answer: - if llm_fallback: - await self._logger.info( - "[Phase 4] Final summary gate rejected evidence, llm_fallback=True → LLM fallback" - ) - answer, should_save = await self._summarise_cluster_fallback(query) - else: - await self._logger.warning( - "[Phase 4] Final summary gate rejected evidence and llm_fallback=False " - "→ returning no results" - ) - return _NO_RESULTS_MESSAGE, None, context + # Tool 4: Directory scan (optional, medium cost) + if enable_dir_scan: + from sirchmunk.agentic.dir_scan_tool import DirScanTool + from sirchmunk.scan.dir_scanner import DirectoryScanner - # Sync LLM token accounting into context - new_usages = self.llm_usages[_llm_usage_start:] - for usage in new_usages: - if usage and isinstance(usage, dict): - total_tok = usage.get("total_tokens", 0) - if total_tok == 0: - total_tok = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0) - context.add_llm_tokens(total_tok, usage=usage) + if self._dir_scanner is None: + self._dir_scanner = DirectoryScanner( + llm=self.llm, max_files=500, + ) + registry.register(DirScanTool( + scanner=self._dir_scanner, + paths=paths, + )) - # ============================================================== - # Phase 5: Persistence (quality-gated) - # Skipped when Phase 4 quality check says the answer is low-quality - # or when Phase 0 reused a cluster (early-returned above). - # ============================================================== - phase5_tasks = [] - if cluster and should_save: - self._add_query_to_cluster(cluster, query) - phase5_tasks.append(self._save_cluster_with_embedding(cluster)) - elif not should_save: - await self._logger.info("[Phase 5] Quality gate: low-quality answer, skipping cluster save") - cluster = None - phase5_tasks.append(self._save_spec_context(paths, context, scan_result=scan_result)) - results = await asyncio.gather(*phase5_tasks, return_exceptions=True) - for r in results: - if isinstance(r, Exception): - _loguru_logger.warning(f"[Phase 5] Persistence task failed: {r}") + # Tool 5: Tree navigation (when compile artifacts exist) + artifacts = self._detect_compile_artifacts(paths) + if artifacts and artifacts.tree_available_paths: + registry.register(TreeNavigationTool( + navigate_fn=self._tree_guided_sample, + available_paths=artifacts.tree_available_paths, + max_chars=self._FAST_MAX_EVIDENCE_CHARS, + )) - await self._logger.success(f"[search] Complete: {context.summary()}") - return answer, cluster, context + self._tool_registry = registry + self._tool_registry_key = cache_key + return registry # ------------------------------------------------------------------ - # Phase 0a: Direct document analysis (intent-gated) + # Knowledge compile entry point # ------------------------------------------------------------------ - async def _try_direct_doc_analysis( + async def compile( self, - query: str, - paths: List[str], - ) -> Optional[str]: - """Short-circuit for document-level queries (e.g. "请总结这篇文档"). + paths: Optional[Union[str, Path, List[str], List[Path]]] = None, + *, + incremental: bool = True, + shallow: bool = False, + max_files: Optional[int] = None, + concurrency: int = 3, + ) -> Dict[str, Any]: + """Compile document collections into structured knowledge indices. - Uses the LLM to classify query intent (language-agnostic). When - a whole-document operation is detected **and** suitable files exist - in *paths*, their content is fed directly to the LLM — bypassing - the heavyweight keyword / dir-scan / evidence pipeline. + Optional offline pre-processing step that builds tree indices and + knowledge clusters. Products are automatically leveraged by + subsequent search() calls. + + Args: + paths: Directories or files to compile. Falls back to self.paths. + incremental: Skip unchanged files (default True). + shallow: Skip tree building — use direct LLM summarisation only. + max_files: Cap on files — triggers importance sampling for large sets. + concurrency: Max parallel file compilations. Returns: - LLM answer string, or None if the short-circuit does not apply. + CompileReport as a dict. """ - from sirchmunk.doc_qa import ( - detect_doc_intent, - collect_doc_files, - analyse_documents, - ) - - # Step 1: file gate — skip early if paths contain no loadable docs - doc_files = collect_doc_files(paths) - if not doc_files: - return None - - # Step 2: LLM intent classification (cheap, stream=False) - operation = await detect_doc_intent(query, self.llm, self.llm_usages) - if operation is None: - return None - - filenames = ", ".join(Path(d.path).name for d in doc_files) - await self._logger.info( - f"[DocQA] Intent '{operation}' detected — " - f"loading {len(doc_files)} file(s) for direct analysis: {filenames}" - ) - - # Step 3: for summary operations, use the chunked summarizer - # with optional smart dir scanning; for other operations, use the - # general analyser. - if operation in ("summarize", "summary", "extract"): - scan_result = None - if self._has_directory_paths(paths): - scan_result = await self._probe_dir_scan(paths, max_files=300) - answer = await self._summarize_documents( - query, paths, scan_result=scan_result, - ) - else: - answer = await analyse_documents( - query=query, - doc_files=doc_files, - llm=self.llm, - llm_usages=self.llm_usages, - ) - - if answer: - await self._logger.success("[DocQA] Direct document analysis complete") - return answer - - # ------------------------------------------------------------------ - # Chat intent detection — short-circuit for non-search queries - # ------------------------------------------------------------------ - - @staticmethod - def _is_chat_query(query: str) -> bool: - """Return True for obvious conversational queries (rule-based, no LLM).""" - return bool(_CHAT_QUERY_RE.match(query.strip())) + from sirchmunk.learnings.compiler import KnowledgeCompiler + from sirchmunk.learnings.tree_indexer import DocumentTreeIndexer - async def _respond_chat( - self, - query: str, - context: Optional[SearchContext] = None, - *, - chat_history: Optional[List[Dict[str, str]]] = None, - ) -> Tuple[str, Optional[KnowledgeCluster], SearchContext]: - """Generate a direct conversational response (single LLM call, no retrieval).""" + resolved = self._resolve_paths(paths) await self._logger.info( - f"[search] Chat intent detected — responding directly: '{query[:60]}'" + f"[Compile] Starting compile for {len(resolved)} path(s)" ) - ctx = context or SearchContext() - messages = [ - {"role": "system", "content": _CHAT_RESPONSE_SYSTEM}, - *(chat_history or []), - {"role": "user", "content": query}, - ] - resp = await self.llm.achat(messages=messages, stream=False) - self.llm_usages.append(resp.usage) - if resp.usage and isinstance(resp.usage, dict): - ctx.add_llm_tokens( - resp.usage.get("total_tokens", 0), usage=resp.usage, - ) - return resp.content or "", None, ctx - - # ------------------------------------------------------------------ - # Document summarization — shared by FAST & DEEP summary intent - # ------------------------------------------------------------------ - - _SUMMARY_MAX_CONTEXT_CHARS = 100_000 - _SUMMARY_CHUNK_CHARS = 50_000 - _SUMMARY_MAX_FILE_SIZE = 200 * 1024 * 1024 # 200 MB — sampling handles large files - - async def _summarize_documents( - self, - query: str, - paths: List[str], - *, - top_k_files: int = 5, - scan_result=None, - ) -> Optional[str]: - """Summarize documents from *paths* with smart content sampling. - - When *scan_result* (from a prior directory scan) is provided, the - LLM ranks candidates first so only the most relevant files are - summarized. Otherwise falls back to ``collect_doc_files``. - - Small files are loaded in full; large files are sampled (head + mid + - tail). When the total content exceeds the LLM context budget, the - documents are processed in chunks — each chunk is summarized - independently, then the partial summaries are merged in a final pass. - - Returns: - Summary string, or ``None`` if no documents could be loaded. - """ - from sirchmunk.doc_qa import collect_doc_files, _extract_text, _sample_text - - summary_paths: Optional[List[str]] = None - - # When a scan result is available, use LLM ranking to pick candidates - if scan_result is not None: - ranked = await self._rank_dir_scan_candidates( - query, scan_result, - top_k=top_k_files * 2, - include_medium=True, - ) - if ranked: - summary_paths = ranked[:top_k_files] - await self._logger.info( - f"[Summary] Dir scan selected {len(summary_paths)} relevant file(s)" - ) - doc_files = collect_doc_files( - summary_paths or paths, - max_files=top_k_files, - max_file_size=self._SUMMARY_MAX_FILE_SIZE, + tree_cache = self.work_path / ".cache" / "compile" / "trees" + _cb = getattr(self._logger, 'log_callback', None) + tree_indexer = DocumentTreeIndexer( + llm=self.llm, + cache_dir=tree_cache, + log_callback=_cb, ) - if not doc_files: - await self._logger.warning( - f"[Summary] No loadable documents found in paths: {paths}" - ) - return None - - doc_texts: List[Tuple[str, str]] = [] - total_chars = 0 - for df in doc_files: - text = await _extract_text(df) - if text: - fname = Path(df.path).name - doc_texts.append((fname, text)) - total_chars += len(text) - else: - await self._logger.warning( - f"[Summary] Text extraction failed for: {Path(df.path).name}" - ) - - if not doc_texts: - await self._logger.warning("[Summary] No text could be extracted from collected documents") - return None - await self._logger.info( - f"[Summary] Loaded {len(doc_texts)} doc(s), " - f"total {total_chars} chars" + compiler = KnowledgeCompiler( + llm=self.llm, + embedding_client=self.embedding_client, + knowledge_storage=self.knowledge_storage, + tree_indexer=tree_indexer, + work_path=self.work_path, + log_callback=_cb, ) - needs_sampling = total_chars > self._SUMMARY_MAX_CONTEXT_CHARS - per_file_budget = ( - self._SUMMARY_MAX_CONTEXT_CHARS // len(doc_texts) - if needs_sampling else 0 + report = await compiler.compile( + paths=resolved, + incremental=incremental, + shallow=shallow, + max_files=max_files, + concurrency=concurrency, ) - parts: List[str] = [] - for fname, text in doc_texts: - content = _sample_text(text, per_file_budget) if needs_sampling else text - parts.append(f"#### File: {fname}\n```\n{content}\n```") - - combined = "\n\n".join(parts) + return report.to_dict() - if len(combined) <= self._SUMMARY_CHUNK_CHARS: - return await self._llm_summarize_docs(combined, query) + async def compile_status( + self, + paths: Optional[Union[str, Path, List[str], List[Path]]] = None, + ) -> Dict[str, Any]: + """Return current compile status for the given paths.""" + from sirchmunk.learnings.compiler import KnowledgeCompiler + from sirchmunk.learnings.tree_indexer import DocumentTreeIndexer - return await self._llm_chunked_summarize(combined, query) + resolved = self._resolve_paths(paths) - async def _llm_summarize_docs(self, documents: str, query: str) -> str: - """Single-pass LLM summarization.""" - prompt = DOC_SUMMARY.format(documents=documents, user_input=query) - resp = await self.llm.achat( - messages=[{"role": "user", "content": prompt}], - stream=True, + tree_cache = self.work_path / ".cache" / "compile" / "trees" + tree_indexer = DocumentTreeIndexer( + llm=self.llm, cache_dir=tree_cache, ) - self.llm_usages.append(resp.usage) - return resp.content or "" - async def _llm_chunked_summarize(self, combined: str, query: str) -> str: - """Multi-pass chunked summarization for large content.""" - chunk_size = self._SUMMARY_CHUNK_CHARS - chunks = [ - combined[i:i + chunk_size] - for i in range(0, len(combined), chunk_size) - ] - await self._logger.info( - f"[Summary] Content exceeds single-pass limit — " - f"splitting into {len(chunks)} chunk(s)" + compiler = KnowledgeCompiler( + llm=self.llm, + embedding_client=self.embedding_client, + knowledge_storage=self.knowledge_storage, + tree_indexer=tree_indexer, + work_path=self.work_path, ) - partial_summaries: List[str] = [] - for idx, chunk in enumerate(chunks, 1): - await self._logger.info(f"[Summary] Summarizing chunk {idx}/{len(chunks)}") - prompt = DOC_CHUNK_SUMMARY.format(chunk=chunk, user_input=query) - resp = await self.llm.achat( - messages=[{"role": "user", "content": prompt}], - stream=True, - ) - self.llm_usages.append(resp.usage) - if resp.content: - partial_summaries.append(resp.content) + status = await compiler.get_status(resolved) + return { + "total_compiled_files": status.total_compiled_files, + "total_clusters": status.total_clusters, + "total_trees": status.total_trees, + "last_compile_at": status.last_compile_at, + "manifest_path": status.manifest_path, + } - if not partial_summaries: - return "" - if len(partial_summaries) == 1: - return partial_summaries[0] + async def compile_lint( + self, + *, + auto_fix: bool = False, + ) -> Dict[str, Any]: + """Run knowledge health checks and optionally auto-fix issues.""" + from sirchmunk.learnings.lint import KnowledgeLint - merged_input = "\n\n---\n\n".join( - f"**Part {i}**\n{s}" for i, s in enumerate(partial_summaries, 1) - ) - prompt = DOC_MERGE_SUMMARIES.format(summaries=merged_input, user_input=query) - resp = await self.llm.achat( - messages=[{"role": "user", "content": prompt}], - stream=True, + linter = KnowledgeLint( + knowledge_storage=self.knowledge_storage, + work_path=self.work_path, + log_callback=getattr(self._logger, 'log_callback', None), ) - self.llm_usages.append(resp.usage) - return resp.content or "" + + report = await linter.run(auto_fix=auto_fix) + return report.to_dict() # ------------------------------------------------------------------ - # FAST mode — greedy search with early termination + # Unified search entry point # ------------------------------------------------------------------ - _FAST_TEXT_EXTENSIONS = { - ".txt", ".md", ".rst", ".csv", ".log", ".tsv", - ".py", ".js", ".ts", ".json", ".yaml", ".yml", ".xml", - ".html", ".htm", ".sh", ".toml", ".cfg", ".ini", ".conf", - ".css", ".bash", ".java", ".c", ".cpp", ".h", ".go", ".rs", - } - _FAST_CONTEXT_WINDOW = 30 # ± lines around each grep hit - _FAST_MAX_EVIDENCE_CHARS = 15_000 - _FAST_SMALL_FILE_THRESHOLD = 100_000 # 100K chars - read full file instead of grep sampling - - _LLM_FALLBACK_EVIDENCE = ( - "[No relevant documents found]\n\n" - "The search did not find relevant content in the available documents. " - "Please answer the user's question based on your own knowledge. " - "Clearly indicate that this answer is from LLM knowledge, " - "not from retrieved documents." - ) - - async def _search_fast( + async def search( self, query: str, - paths: List[str], + paths: Optional[Union[str, Path, List[str], List[Path]]] = None, *, + mode: Literal["DEEP", "FAST", "FILENAME_ONLY"] = "FAST", + max_loops: int = 10, + max_token_budget: int = 128000, max_depth: Optional[int] = 5, - top_k_files: int = 3, + top_k_files: int = 5, enable_dir_scan: bool = False, include: Optional[List[str]] = None, exclude: Optional[List[str]] = None, + return_context: bool = False, + spec_stale_hours: float = 72.0, + chat_history: Optional[List[Dict[str, str]]] = None, llm_fallback: bool = False, - ) -> Tuple[str, Optional[KnowledgeCluster], SearchContext]: - """Greedy search: 2-3 LLM calls, single best file, focused evidence. - - Two-level keyword cascade extracted in one LLM call: - primary (compound phrase) is tried first; if it misses, fallback - (atomic terms) is tried. When ``enable_dir_scan`` is True and - paths contain directories, a directory scan runs concurrently with - keyword extraction and acts as a fallback retrieval path. + ) -> Union[str, SearchContext, List[Dict[str, Any]]]: + """Perform intelligent search with multi-mode support. + + Modes: + +--------------+-------------------+-------------------------------------------+ + | Mode | Speed / LLM Calls | Description | + +--------------+-------------------+-------------------------------------------+ + | FILENAME_ONLY| Very Fast / 0 | Pattern-based file discovery, no LLM. | + | FAST | 1-5s / 0-2 | Greedy: cluster reuse or keyword search | + | | | → best file → answer. Early termination. | + | DEEP | 5-30s / 4-6 | Parallel multi-path retrieval + ReAct | + | | | refinement with Monte-Carlo evidence. | + +--------------+-------------------+-------------------------------------------+ + + FAST architecture (greedy early-termination): + + ┌──────────────────────────────────────────────────────────┐ + │ Step 0 Cluster reuse check (instant short-circuit) │ + ├──────────────────────────────────────────────────────────┤ + │ Step 1 LLM query analysis → keywords + file hints │ + │ (single call, stream=False) │ + ├──────────────────────────────────────────────────────────┤ + │ Step 2 rga keyword search → ranked file hits + snippets │ + │ (no LLM, greedy: take first good results) │ + ├──────────────────────────────────────────────────────────┤ + │ Step 3 Read top file(s) content │ + │ (no LLM, early termination at top_k_files) │ + ├──────────────────────────────────────────────────────────┤ + │ Step 4 LLM answer synthesis from evidence │ + └──────────────────────────────────────────────────────────┘ + + DEEP architecture (phases execute as parallel as possible): + + ┌──────────────────────────────────────────────────────────┐ + │ Phase 0a Direct document analysis (intent-gated, │ + │ short-circuit if query is doc-level operation) │ + ├──────────────────────────────────────────────────────────┤ + │ Phase 0 Cluster reuse check (instant, short-circuit) │ + ├──────────────────────────────────────────────────────────┤ + │ Phase 1 Parallel probing (all concurrent): │ + │ ├─ LLM keyword extraction │ + │ ├─ DirectoryScanner.scan() (filesystem only, fast) │ + │ ├─ Knowledge cache similarity search │ + │ └─ Spec-path cache load │ + ├──────────────────────────────────────────────────────────┤ + │ Phase 2 Parallel retrieval (depends on Phase 1): │ + │ ├─ keyword_search per extracted keyword (concurrent rga)│ + │ └─ DirectoryScanner.rank() (LLM ranks candidates) │ + ├──────────────────────────────────────────────────────────┤ + │ Phase 3 Merge + evidence assembly: │ + │ └─ knowledge_base.build() (parallel per-file Monte │ + │ Carlo evidence sampling) │ + ├──────────────────────────────────────────────────────────┤ + │ Phase 4 Summary / ReAct refinement: │ + │ └─ If evidence sufficient → LLM summary │ + │ Else → ReAct loop for adaptive follow-up │ + ├──────────────────────────────────────────────────────────┤ + │ Phase 5 Persistence (concurrent, awaited): │ + │ ├─ Save cluster + embeddings │ + │ └─ Save spec-path cache │ + └──────────────────────────────────────────────────────────┘ + + Args: + query: User's search query. + paths: Directories / files to search. Falls back to + ``self.paths`` or the current working directory. + mode: Search mode — ``"DEEP"``, ``"FAST"``, or ``"FILENAME_ONLY"``. + max_loops: Maximum ReAct iterations (DEEP mode, default: 10). + max_token_budget: LLM token budget (DEEP mode, default: 128000). + max_depth: Maximum directory depth for file search (default: 5). + Used in both FILENAME_ONLY and DEEP modes. + top_k_files: Max files for evidence extraction (default: 5). + enable_dir_scan: Enable directory scanning (FAST and DEEP modes). + include: File glob patterns to include (e.g. ``["*.py", "*.md"]``). + Used in both FILENAME_ONLY and DEEP modes. + exclude: File glob patterns to exclude (e.g. ``["*.log"]``). + Used in both FILENAME_ONLY and DEEP modes. + return_context: If True, return a ``SearchContext`` object + that carries ``answer``, ``cluster`` (KnowledgeCluster), + and full pipeline telemetry (LLM usage, files read, etc.). + spec_stale_hours: Hours before spec cache is stale (default: 72). + chat_history: Optional list of chat messages for context (DEEP mode). + llm_fallback: When True, if no relevant documents are found, + the LLM will attempt to answer the query from its own + knowledge. Default False. + + Returns: + - ``str``: Answer summary (default). + - ``SearchContext``: If *return_context* — contains ``answer``, + ``cluster``, and telemetry in a single object. + - ``List[Dict]``: File matches in FILENAME_ONLY mode. + """ + paths = self.validate_search_paths( + self._resolve_paths(paths), + ) + if not paths: + msg = "No valid search paths remain after validation." + _loguru_logger.warning(msg) + if return_context: + ctx = SearchContext() + ctx.answer = msg + return ctx + return msg + + await self._logger.info(f"[SearchConfig] PURE_TREE_SEARCH={'enabled' if _PURE_TREE_SEARCH else 'disabled'}") + + # ---- Chat intent short-circuit (rule-based, no LLM cost) ---- + if mode != "FILENAME_ONLY" and self._is_chat_query(query): + answer, cluster, ctx = await self._respond_chat(query, chat_history=chat_history) + if return_context: + ctx.answer = answer + return ctx + return answer + + # ---- FILENAME_ONLY: pattern-based file discovery, no LLM ---- + if mode == "FILENAME_ONLY": + results = await self._search_by_filename( + query=query, paths=paths, max_depth=max_depth, + include=include, exclude=exclude, top_k=top_k_files, + ) + if not results: + msg = f"No files found matching query: '{query}'" + await self._logger.warning(msg) + return msg + await self._logger.success(f"Retrieved {len(results)} matching files") + return results + + # ---- FAST / DEEP → both produce (answer, cluster, context) ---- + if mode == "FAST": + answer, cluster, context = await self._search_fast( + query=query, paths=paths, max_depth=max_depth, + top_k_files=top_k_files, enable_dir_scan=enable_dir_scan, + include=include, exclude=exclude, + llm_fallback=llm_fallback, + ) + else: + answer, cluster, context = await self._search_deep( + query=query, paths=paths, + max_loops=max_loops, max_token_budget=max_token_budget, + max_depth=max_depth, top_k_files=top_k_files, + enable_dir_scan=enable_dir_scan, + include=include, exclude=exclude, + spec_stale_hours=spec_stale_hours, + llm_fallback=llm_fallback, + ) + + # ---- Unified return wrapping ---- + if return_context: + prefix = "FS" if mode == "FAST" else "DS" + context.answer = answer + if (answer or "").strip().lower() == _NO_RESULTS_MESSAGE.lower(): + context.cluster = cluster + return context + # Use read_file_ids from context if available, otherwise empty + fallback_files = list(context.read_file_ids) if context.read_file_ids else None + context.cluster = cluster or self._make_answer_cluster( + query, answer, prefix, file_paths=fallback_files, + ) + return context + return answer + + # ------------------------------------------------------------------ + # DEEP mode — parallel multi-path retrieval with ReAct fallback + # ------------------------------------------------------------------ + + async def _search_deep( + self, + query: str, + paths: List[str], + *, + max_loops: int = 10, + max_token_budget: int = 128000, + max_depth: Optional[int] = 5, + top_k_files: int = 5, + enable_dir_scan: bool = False, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + spec_stale_hours: float = 72.0, + llm_fallback: bool = False, + ) -> Tuple[str, Optional[KnowledgeCluster], SearchContext]: + """Parallel multi-path retrieval pipeline (Phases 0a–5). + + Returns: + ``(answer, cluster, context)`` tuple. + """ + context = SearchContext( + max_token_budget=max_token_budget, + max_loops=max_loops, + ) + _llm_usage_start = len(self.llm_usages) + + # --- Adaptive compile artifact detection (shared with FAST) --- + _scope = _PathScope(paths) + artifacts = self._detect_compile_artifacts(paths) + + # ============================================================== + # Phase 0a: Direct document analysis (intent-gated short-circuit) + # ============================================================== + direct = await self._try_direct_doc_analysis(query, paths) + if direct is not None: + return direct, self._make_answer_cluster(query, direct, "DQ", file_paths=paths), context + + # ============================================================== + # Phase 0: Cluster reuse (instant short-circuit) + # When reuse_knowledge=True and a similar cluster is found, we + # return here — Phase 5 (Persistence) is not executed for that path. + # ============================================================== + reused = await self._try_reuse_cluster(query, paths) + if reused is not None: + return self._enrich_reused_content(reused), reused, context + + # P2: gradient reuse — extract hints from moderately similar clusters + soft_hit = await self._try_soft_reuse(query, paths) + + await self._logger.info(f"[search] Starting multi-path retrieval for: '{query[:80]}'") + + # ============================================================== + # Phase 1: Parallel probing — five paths fire concurrently + # ============================================================== + await self._logger.info("[Phase 1] Parallel probing: keywords + dir_scan + knowledge + spec_cache + tree_index") + context.increment_loop() + + phase1_results = await asyncio.gather( + self._probe_keywords(query), + self._probe_dir_scan(paths, enable_dir_scan), + self._probe_knowledge_cache(query), + self._load_spec_context(paths, stale_hours=spec_stale_hours), + self._probe_tree_index(query), + self._probe_compile_hints([query], scope=_scope), # query-level hints; keyword-level runs post-Phase 1 + self._probe_summary_index(query, artifacts, scope=_scope), # GAP 2: zero-LLM BM25 + self._probe_catalog_for_deep(query, artifacts), # GAP 4: zero-LLM keyword overlap + return_exceptions=True, + ) + + kw_result = phase1_results[0] if not isinstance(phase1_results[0], Exception) else ({}, []) + scan_result = phase1_results[1] if not isinstance(phase1_results[1], Exception) else None + knowledge_probe = phase1_results[2] if not isinstance(phase1_results[2], Exception) else KnowledgeProbeResult([], [], "") + spec_context = phase1_results[3] if not isinstance(phase1_results[3], Exception) else "" + tree_hits = phase1_results[4] if not isinstance(phase1_results[4], Exception) else [] + compile_hints = phase1_results[5] if not isinstance(phase1_results[5], Exception) else CompileHints([], []) + summary_index_hits = phase1_results[6] if not isinstance(phase1_results[6], Exception) else [] + catalog_deep_hits = phase1_results[7] if not isinstance(phase1_results[7], Exception) else [] + + for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index", "compile_hints", "summary_index", "catalog_deep"]): + if isinstance(phase1_results[i], Exception): + await self._logger.warning(f"[Phase 1] {label} probe failed: {phase1_results[i]}") + + # Backwards compat: knowledge_probe may be a plain list from old code paths + if isinstance(knowledge_probe, list): + knowledge_probe = KnowledgeProbeResult(file_paths=knowledge_probe, extra_keywords=[], background_context="") + + query_keywords, initial_keywords = kw_result if isinstance(kw_result, tuple) else ({}, []) + + # P2: inject soft-hit patterns into keywords + if soft_hit: + for p in soft_hit.patterns: + if p not in initial_keywords: + initial_keywords.append(p) + if p not in query_keywords: + query_keywords[p] = 0.6 + + # P3: inject extra keywords from structured knowledge probe + for kw in knowledge_probe.extra_keywords: + if kw not in initial_keywords: + initial_keywords.append(kw) + if kw not in query_keywords: + query_keywords[kw] = 0.5 + + # P2 + P3: append background context for Phase 4 LLM prompt + if soft_hit and soft_hit.context_summary: + spec_context = f"{spec_context}\n\n{soft_hit.context_summary}" if spec_context else soft_hit.context_summary + if knowledge_probe.background_context: + spec_context = f"{spec_context}\n\n{knowledge_probe.background_context}" if spec_context else knowledge_probe.background_context + + await self._logger.info( + f"[Phase 1] Results: keywords={len(initial_keywords)}, " + f"dir_scan={'OK' if scan_result else 'N/A'}, " + f"knowledge_files={len(knowledge_probe.file_paths)}, " + f"tree_hits={len(tree_hits)}, " + f"compile_hints={len(compile_hints.file_paths)}, " + f"summary_index={len(summary_index_hits)}, " + f"catalog_deep={len(catalog_deep_hits)}, " + f"soft_hit={'YES' if soft_hit else 'NO'}, " + f"spec_cache={'YES' if spec_context else 'NO'}" + ) + + # ============================================================== + # Phase 2: Parallel retrieval — keyword search + dir_scan rank + # ============================================================== + keyword_files: List[str] = [] + dir_scan_files: List[str] = [] + + if _PURE_TREE_SEARCH: + # Pure tree search mode: skip rga and dir_scan, rely solely on tree hits + await self._logger.info("[Phase 2:PureTree] Skipping rga keyword search and dir_scan") + context.increment_loop() + else: + await self._logger.info("[Phase 2] Parallel retrieval: rga keyword search + dir_scan LLM rank") + context.increment_loop() + + phase2_tasks = [] + + if initial_keywords: + phase2_tasks.append( + self._retrieve_by_keywords( + initial_keywords, paths, + max_depth=max_depth, include=include, exclude=exclude, + ) + ) + else: + phase2_tasks.append(self._async_noop([])) + + if scan_result is not None and enable_dir_scan: + phase2_tasks.append( + self._rank_dir_scan_candidates(query, scan_result) + ) + else: + phase2_tasks.append(self._async_noop([])) + + phase2_results = await asyncio.gather(*phase2_tasks, return_exceptions=True) + + keyword_files = phase2_results[0] if not isinstance(phase2_results[0], Exception) else [] + dir_scan_files = phase2_results[1] if not isinstance(phase2_results[1], Exception) else [] + + for i, label in enumerate(["keyword_search", "dir_scan_rank"]): + if isinstance(phase2_results[i], Exception): + await self._logger.warning(f"[Phase 2] {label} failed: {phase2_results[i]}") + + await self._logger.info( + f"[Phase 2] Results: keyword_files={len(keyword_files)}, " + f"dir_scan_files={len(dir_scan_files)}" + ) + + # ============================================================== + # Phase 3: Query analysis + file selection + # ============================================================== + context.increment_loop() + _query_complexity, _query_intent = await self._classify_query_intent(query) + data_reqs = await self._analyze_data_requirements(query, _query_intent) + context.increment_loop() + + await self._logger.info( + f"[Phase 3] Query: complexity={_query_complexity}, " + f"intent={_query_intent}, " + f"data_points={len(data_reqs.data_points)}, " + f"formula={data_reqs.formula or 'N/A'}" + ) + + extra_knowledge_files = knowledge_probe.file_paths + if soft_hit: + extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files + + merged_files = self._merge_file_paths( + keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, + dir_scan_files=dir_scan_files, + knowledge_hits=extra_knowledge_files, + ) + target_files = self._select_target_files(merged_files, _scope, artifacts) + + await self._logger.info( + f"[Phase 3] Merged {len(merged_files)} files, " + f"target {len(target_files)} for agentic retrieval" + ) + + # ============================================================== + # Phase 4: Agentic retrieval loop + # ============================================================== + retrieval = await self._agentic_retrieve( + query, data_reqs, target_files, context, + ) + + await self._logger.info( + f"[Phase 4] Retrieval: {retrieval.rounds_used} rounds, " + f"complete={retrieval.is_complete}, " + f"{sum(len(ps) for ps in retrieval.pages_extracted.values())} pages" + ) + + # ============================================================== + # Phase 4.5: Synthesis + # ============================================================== + answer, should_save, cluster = await self._synthesize_from_retrieval( + query, _query_intent, retrieval, merged_files, + formula=data_reqs.formula, + ) + + # ============================================================== + # Phase 4.75: Computation verification + # ============================================================== + if answer and answer != _NO_RESULTS_MESSAGE and _query_intent == "computation": + answer, was_corrected = await self._verify_computation(query, answer) + if was_corrected: + _, should_save, _ = self._parse_summary_response(answer) + + # Sync LLM token accounting into context + new_usages = self.llm_usages[_llm_usage_start:] + for usage in new_usages: + if usage and isinstance(usage, dict): + total_tok = usage.get("total_tokens", 0) + if total_tok == 0: + total_tok = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0) + context.add_llm_tokens(total_tok, usage=usage) + + # ============================================================== + # Phase 5: Persistence (quality-gated) + # Skipped when Phase 4 quality check says the answer is low-quality + # or when Phase 0 reused a cluster (early-returned above). + # ============================================================== + phase5_tasks = [] + if cluster and should_save: + self._add_query_to_cluster(cluster, query) + phase5_tasks.append(self._save_cluster_with_embedding(cluster)) + elif not should_save: + await self._logger.info("[Phase 5] Quality gate: low-quality answer, skipping cluster save") + cluster = None + phase5_tasks.append(self._save_spec_context(paths, context, scan_result=scan_result)) + results = await asyncio.gather(*phase5_tasks, return_exceptions=True) + for r in results: + if isinstance(r, Exception): + _loguru_logger.warning(f"[Phase 5] Persistence task failed: {r}") + + await self._logger.success(f"[search] Complete: {context.summary()}") + return answer, cluster, context + + # ------------------------------------------------------------------ + # Phase 0a: Direct document analysis (intent-gated) + # ------------------------------------------------------------------ + + async def _try_direct_doc_analysis( + self, + query: str, + paths: List[str], + ) -> Optional[str]: + """Short-circuit for document-level queries (e.g. "请总结这篇文档"). + + Uses the LLM to classify query intent (language-agnostic). When + a whole-document operation is detected **and** suitable files exist + in *paths*, their content is fed directly to the LLM — bypassing + the heavyweight keyword / dir-scan / evidence pipeline. + + Returns: + LLM answer string, or None if the short-circuit does not apply. + """ + from sirchmunk.doc_qa import ( + detect_doc_intent, + collect_doc_files, + analyse_documents, + ) + + # Step 1: file gate — skip early if paths contain no loadable docs + doc_files = collect_doc_files(paths) + if not doc_files: + return None + + # Step 2: LLM intent classification (cheap, stream=False) + operation = await detect_doc_intent(query, self.llm, self.llm_usages) + if operation is None: + return None + + # Computation/comparison queries need the full evidence pipeline + if re.search( + r'\b(?:ratio|margin|growth.?rate|turnover|coverage' + r'|what is (?:the )?fy\d|calculate|compute' + r'|improv(?:ing|ed)|declin(?:ing|ed)' + r'|which .{0,30}(?:best|worst|most|least|highest|lowest))\b', + query, re.IGNORECASE, + ): + return None + + filenames = ", ".join(Path(d.path).name for d in doc_files) + await self._logger.info( + f"[DocQA] Intent '{operation}' detected — " + f"loading {len(doc_files)} file(s) for direct analysis: {filenames}" + ) + + # Step 3: for summary operations, use the chunked summarizer + # with optional smart dir scanning; for other operations, use the + # general analyser. + if operation in ("summarize", "summary", "extract"): + scan_result = None + if self._has_directory_paths(paths): + scan_result = await self._probe_dir_scan(paths, max_files=300) + answer = await self._summarize_documents( + query, paths, scan_result=scan_result, + ) + else: + answer = await analyse_documents( + query=query, + doc_files=doc_files, + llm=self.llm, + llm_usages=self.llm_usages, + ) + + if answer: + await self._logger.success("[DocQA] Direct document analysis complete") + return answer + + # ------------------------------------------------------------------ + # Chat intent detection — short-circuit for non-search queries + # ------------------------------------------------------------------ + + @staticmethod + def _is_chat_query(query: str) -> bool: + """Return True for obvious conversational queries (rule-based, no LLM).""" + return bool(_CHAT_QUERY_RE.match(query.strip())) + + async def _respond_chat( + self, + query: str, + context: Optional[SearchContext] = None, + *, + chat_history: Optional[List[Dict[str, str]]] = None, + ) -> Tuple[str, Optional[KnowledgeCluster], SearchContext]: + """Generate a direct conversational response (single LLM call, no retrieval).""" + await self._logger.info( + f"[search] Chat intent detected — responding directly: '{query[:60]}'" + ) + ctx = context or SearchContext() + messages = [ + {"role": "system", "content": _CHAT_RESPONSE_SYSTEM}, + *(chat_history or []), + {"role": "user", "content": query}, + ] + resp = await self.llm.achat(messages=messages, stream=False) + self.llm_usages.append(resp.usage) + if resp.usage and isinstance(resp.usage, dict): + ctx.add_llm_tokens( + resp.usage.get("total_tokens", 0), usage=resp.usage, + ) + return resp.content or "", None, ctx + + # ------------------------------------------------------------------ + # Document summarization — shared by FAST & DEEP summary intent + # ------------------------------------------------------------------ + + _SUMMARY_MAX_CONTEXT_CHARS = 100_000 + _SUMMARY_CHUNK_CHARS = 50_000 + _SUMMARY_MAX_FILE_SIZE = 200 * 1024 * 1024 # 200 MB — sampling handles large files + + async def _summarize_documents( + self, + query: str, + paths: List[str], + *, + top_k_files: int = 5, + scan_result=None, + ) -> Optional[str]: + """Summarize documents from *paths* with smart content sampling. + + When *scan_result* (from a prior directory scan) is provided, the + LLM ranks candidates first so only the most relevant files are + summarized. Otherwise falls back to ``collect_doc_files``. + + Small files are loaded in full; large files are sampled (head + mid + + tail). When the total content exceeds the LLM context budget, the + documents are processed in chunks — each chunk is summarized + independently, then the partial summaries are merged in a final pass. + + Returns: + Summary string, or ``None`` if no documents could be loaded. + """ + from sirchmunk.doc_qa import collect_doc_files, _extract_text, _sample_text + + summary_paths: Optional[List[str]] = None + + # When a scan result is available, use LLM ranking to pick candidates + if scan_result is not None: + ranked = await self._rank_dir_scan_candidates( + query, scan_result, + top_k=top_k_files * 2, + include_medium=True, + ) + if ranked: + summary_paths = ranked[:top_k_files] + await self._logger.info( + f"[Summary] Dir scan selected {len(summary_paths)} relevant file(s)" + ) + + doc_files = collect_doc_files( + summary_paths or paths, + max_files=top_k_files, + max_file_size=self._SUMMARY_MAX_FILE_SIZE, + ) + if not doc_files: + await self._logger.warning( + f"[Summary] No loadable documents found in paths: {paths}" + ) + return None + + doc_texts: List[Tuple[str, str]] = [] + total_chars = 0 + for df in doc_files: + text = await _extract_text(df) + if text: + fname = Path(df.path).name + doc_texts.append((fname, text)) + total_chars += len(text) + else: + await self._logger.warning( + f"[Summary] Text extraction failed for: {Path(df.path).name}" + ) + + if not doc_texts: + await self._logger.warning("[Summary] No text could be extracted from collected documents") + return None + + await self._logger.info( + f"[Summary] Loaded {len(doc_texts)} doc(s), " + f"total {total_chars} chars" + ) + + needs_sampling = total_chars > self._SUMMARY_MAX_CONTEXT_CHARS + per_file_budget = ( + self._SUMMARY_MAX_CONTEXT_CHARS // len(doc_texts) + if needs_sampling else 0 + ) + + parts: List[str] = [] + for fname, text in doc_texts: + content = _sample_text(text, per_file_budget) if needs_sampling else text + parts.append(f"#### File: {fname}\n```\n{content}\n```") + + combined = "\n\n".join(parts) + + if len(combined) <= self._SUMMARY_CHUNK_CHARS: + return await self._llm_summarize_docs(combined, query) + + return await self._llm_chunked_summarize(combined, query) + + async def _llm_summarize_docs(self, documents: str, query: str) -> str: + """Single-pass LLM summarization.""" + prompt = DOC_SUMMARY.format(documents=documents, user_input=query) + resp = await self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + return resp.content or "" + + async def _llm_chunked_summarize(self, combined: str, query: str) -> str: + """Multi-pass chunked summarization for large content.""" + chunk_size = self._SUMMARY_CHUNK_CHARS + chunks = [ + combined[i:i + chunk_size] + for i in range(0, len(combined), chunk_size) + ] + await self._logger.info( + f"[Summary] Content exceeds single-pass limit — " + f"splitting into {len(chunks)} chunk(s)" + ) + + partial_summaries: List[str] = [] + for idx, chunk in enumerate(chunks, 1): + await self._logger.info(f"[Summary] Summarizing chunk {idx}/{len(chunks)}") + prompt = DOC_CHUNK_SUMMARY.format(chunk=chunk, user_input=query) + resp = await self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + if resp.content: + partial_summaries.append(resp.content) + + if not partial_summaries: + return "" + if len(partial_summaries) == 1: + return partial_summaries[0] + + merged_input = "\n\n---\n\n".join( + f"**Part {i}**\n{s}" for i, s in enumerate(partial_summaries, 1) + ) + prompt = DOC_MERGE_SUMMARIES.format(summaries=merged_input, user_input=query) + resp = await self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + return resp.content or "" + + # ------------------------------------------------------------------ + # FAST mode — greedy search with early termination + # ------------------------------------------------------------------ + + _FAST_TEXT_EXTENSIONS = { + ".txt", ".md", ".rst", ".csv", ".log", ".tsv", + ".py", ".js", ".ts", ".json", ".yaml", ".yml", ".xml", + ".html", ".htm", ".sh", ".toml", ".cfg", ".ini", ".conf", + ".css", ".bash", ".java", ".c", ".cpp", ".h", ".go", ".rs", + } + _FAST_CONTEXT_WINDOW = 30 # ± lines around each grep hit + _FAST_MAX_EVIDENCE_CHARS = 40_000 + _FAST_SMALL_FILE_THRESHOLD = 100_000 # 100K chars - read full file instead of grep sampling + + # --- Wiki-enhanced ranking constants --- + _WIKI_BLEND_ALPHA = 0.85 + """TF-IDF weight in the hybrid score; Wiki weight = 1 - alpha.""" + _WIKI_MAX_SCORE = 10.0 + """Upper bound for the wiki relevance score.""" + _WIKI_CATALOG_KEYWORD_OVERLAP_MAX = 5.0 + """Maximum sub-score for catalog summary keyword overlap.""" + _WIKI_TREE_AVAILABILITY_BONUS = 0.5 + """Bonus for files that have a compiled tree index (weak signal).""" + _WIKI_CATALOG_PRESENCE_FULL = 2.0 + """Catalog presence bonus for summaries > 100 chars.""" + _WIKI_CATALOG_PRESENCE_MEDIUM = 1.5 + """Catalog presence bonus for summaries > 30 chars (must be < FULL).""" + _WIKI_CATALOG_PRESENCE_MINIMAL = 1.0 + """Catalog presence bonus for summaries > 0 chars.""" + _TREE_CACHE_SCAN_LIMIT = 200 + """Max tree JSON files to parse during artifact detection.""" + _CATALOG_LISTING_MAX_ENTRIES = 20 + """Max catalog entries in the enriched listing for Step 1.""" + _ENABLE_EMBEDDING_FALLBACK: bool = True + """Enable embedding + BM25 hybrid fallback when rga returns zero results.""" + _CATALOG_KEYWORD_MIN_LEN = 2 + """Minimum character length for a catalog keyword token.""" + _CATALOG_KEYWORD_MAX_LEN = 20 + """Maximum character length for a catalog keyword token.""" + _CATALOG_SUMMARY_TRUNCATE = 200 + """Max chars of catalog summary shown in the listing.""" + _SUMMARY_INDEX_TOP_K = 3 + """Maximum files returned by proactive summary index BM25 probe.""" + _DEEP_CATALOG_TOP_K = 3 + """Maximum files returned by catalog keyword-overlap probe in DEEP mode.""" + + # --- Tree-guided sampling constants --- + _TREE_SAMPLE_MAX_SECTIONS = 8 + """Max tree sections to include per file in tree-guided sampling.""" + _TREE_SAMPLE_SECTION_MAX_CHARS = 3000 + """Max chars per tree section.""" + _TREE_SAMPLE_RGA_SUPPLEMENT = True + """Whether to append rga evidence after tree sections as supplementary context.""" + _TREE_ROOT_HINTS_MAX_FILES = 10 + """Maximum number of tree roots to include in FAST Step 1 hints.""" + _DEEP_PRE_NAV_MAX_FILES = 3 + """Maximum number of tree files to pre-navigate in DEEP Phase 2.5.""" + _FAST_TREE_PROBE_MAX_FILES = 2 + """Maximum files returned by active tree probing in FAST mode.""" + _DEEP_TREE_PROBE_MAX_FILES = 3 + """Maximum files returned by tree index probing in DEEP mode.""" + _TREE_ROOT_HINT_TRUNCATE = 150 + """Max chars of tree root summary in Step 1 structure hints.""" + _CHAR_RANGE_MAX_SPAN_RATIO: float = 0.8 + """char_range spanning more than this ratio of the document is treated as invalid.""" + + # --- Tree probe / RGA fusion --- + _TREE_PROBE_RANKING_BOOST: float = 3.0 + """Score boost (0-10 scale) for files selected by LLM tree probing.""" + + # --- Hierarchical file selection for large tree pools --- + _TREE_PREFILTER_THRESHOLD: int = 15 + """Tree pool size above which rule-based pre-filtering is applied.""" + _TREE_PREFILTER_MAX_CANDIDATES: int = 10 + """Maximum candidate trees forwarded to the LLM after pre-filtering.""" + _TREE_PREFILTER_MIN_SCORE: float = 0.5 + """Minimum relevance score for a tree to survive pre-filtering.""" + + # --- Tree navigation --- + _TREE_NAV_MAX_RESULTS: int = 8 + """Primary max_results for LLM-driven tree navigation.""" + _NAV_RETRY_MIN_EVIDENCE_CHARS: int = 200 + """Evidence below this length triggers a retry with expanded results.""" + _NAV_RETRY_EXPANDED_RESULTS: int = 12 + """Expanded max_results for retry navigation pass.""" + + _CHAR_RANGE_MIN_SPAN: int = 200 + """Minimum char_range span to trust as substantive content. + + Nodes whose char_range covers fewer characters than this threshold + (e.g. a TOC entry that only records the section title) are demoted + to page-level extraction when a valid page_range is available. + """ + + _NAV_COMPLEMENT_MIN_COMPONENTS: int = 2 + """Minimum query decomposition components to trigger complementary navigation.""" + + _NAV_PAGE_MARGIN: int = 1 + """Extra pages to extract on each side of a leaf's page_range.""" + + _NAV_REF_PAGE_MAX: int = 5 + """Maximum referenced-but-uncovered pages to extract as gap-fill.""" + + # --- Table evidence budgets --- + _TABLE_EVIDENCE_DEFAULT_CHARS: int = 20_000 + """Default max_chars for _format_table_evidence.""" + _TABLE_EVIDENCE_PER_RANGE_CHARS: int = 8_000 + """Max chars for per-page-range table supplement in tree nav.""" + _TABLE_EVIDENCE_STANDALONE_CHARS: int = 20_000 + """Max chars for standalone table digest fallback when tree nav evidence is thin.""" + _TABLE_CROSS_SECTION_CHARS: int = 6_000 + """Max chars for cross-section table supplement drawn from pages outside + the navigated leaf ranges. Ensures data-dense tables in distant + document sections (e.g. financial statements when leaves are in + management discussion) are included.""" + _TABLE_EVIDENCE_NAV_OVERLAP_CHARS: int = 8_000 + """Reduced table evidence budget for files that are already receiving + parallel tree navigation. Since tree_ev will provide targeted evidence, + the RGA path uses a smaller budget to supply incremental tables, + leaving room for more diverse evidence.""" + _DEEP_CROSS_SECTION_MIN_EVIDENCE: int = 8_000 + """Cross-section table supplement is skipped when existing tree-nav + evidence already exceeds this threshold (chars), preventing overload.""" + + # --- Self-correction expanded sampling --- + _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 10 + """Expanded tree navigation leaf count for same-file re-sampling (default nav uses 5).""" + _SELF_CORRECT_EXPANDED_SECTIONS: int = 8 + """Expanded tree sample sections for same-file re-sampling (default uses 5).""" + + # --- Deep Structured Reasoning --- + _DEEP_SECTION_MAP_MAX_DEPTH: int = 3 + """Maximum tree depth for section map construction (top-N layers).""" + _DEEP_MAX_EXTRACT_PAGES: int = 12 + """Maximum pages to extract per file in targeted page extraction.""" + _DEEP_STRUCTURED_MAX_CHARS: int = 30_000 + """Maximum character budget for structured evidence per file.""" + _DEEP_MAX_RECOVERY_ROUNDS: int = 3 + """Maximum rounds of missing-data recovery before final answer.""" + _DEEP_STRUCTURED_MAX_FILES: int = 3 + """Maximum files to process through structured reasoning pipeline.""" + + # --- Agentic retrieval --- + _AGENTIC_MAX_ROUNDS: int = 3 + """Maximum retrieval rounds in the agentic loop.""" + _AGENTIC_MAX_PAGES_PER_ROUND: int = 8 + """Maximum new pages to extract per round per file.""" + _AGENTIC_MAX_TOTAL_PAGES: int = 20 + """Maximum total pages across all rounds.""" + _AGENTIC_MAX_FILES: int = 3 + """Maximum files to process through agentic retrieval.""" + _AGENTIC_SECTION_MAP_DEPTH: int = 8 + """Section map depth for agentic page selection.""" + _AGENTIC_EVIDENCE_MAX_CHARS: int = 40_000 + """Maximum evidence characters to feed to synthesis prompt.""" + _SHORT_DOC_THRESHOLD: int = 30 + """Documents with this many pages or fewer are extracted in full.""" + + # --- Evidence acceptance thresholds --- + _EVIDENCE_MIN_ACCEPT_LENGTH: int = 800 + """Minimum evidence character length for heuristic override.""" + _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.3 + """Minimum keyword coverage ratio for heuristic override.""" + + _NUMERIC_INTENT_KEYWORDS: frozenset = frozenset({ + "revenue", "margin", "ratio", "ebitda", "income", "profit", "loss", + "cash", "debt", "equity", "eps", "dpo", "growth", "rate", + "percentage", "amount", "total", "net", "gross", "cost", "expense", + "sales", "fy", "fiscal", + }) + """Keywords indicating numeric/financial intent in a query.""" + + _LLM_FALLBACK_EVIDENCE = ( + "[No relevant documents found]\n\n" + "The search did not find relevant content in the available documents. " + "Please answer the user's question based on your own knowledge. " + "Clearly indicate that this answer is from LLM knowledge, " + "not from retrieved documents." + ) + + async def _search_fast( + self, + query: str, + paths: List[str], + *, + max_depth: Optional[int] = 5, + top_k_files: int = 3, + enable_dir_scan: bool = False, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + llm_fallback: bool = False, + ) -> Tuple[str, Optional[KnowledgeCluster], SearchContext]: + """Greedy search: 2-3 LLM calls, single best file, focused evidence. + + Two-level keyword cascade extracted in one LLM call: + primary (compound phrase) is tried first; if it misses, fallback + (atomic terms) is tried. When ``enable_dir_scan`` is True and + paths contain directories, a directory scan runs concurrently with + keyword extraction and acts as a fallback retrieval path. + + Returns: + ``(answer, cluster, context)`` — same triple as ``_search_deep`` + so the caller can handle both modes uniformly. + """ + context = SearchContext() + await self._logger.info(f"[FAST] Starting greedy search for: '{query[:80]}'") + + # Reset per-session tree navigation cache + self._tree_nav_cache = _TreeNavCache() + + # --- Adaptive compile artifact detection (one-shot, zero LLM) --- + _scope = _PathScope(paths) + artifacts = self._detect_compile_artifacts(paths) + if artifacts.catalog or artifacts.tree_available_paths: + await self._logger.info( + f"[FAST:Artifacts] catalog={'yes' if artifacts.catalog else 'no'} " + f"({len(artifacts.catalog) if artifacts.catalog else 0} docs), " + f"trees={len(artifacts.tree_available_paths)}" + ) + + # ============================================================== + # Step 0: Cluster reuse — instant short-circuit (no LLM cost) + # When reuse succeeds we return here; no persistence step runs. + # ============================================================== + reused = await self._try_reuse_cluster(query, paths) + if reused is not None: + await self._logger.success("[FAST] Reused cached knowledge cluster") + return self._enrich_reused_content(reused), reused, context + + # P2: gradient reuse — structured hints from moderately similar clusters + soft_hit = await self._try_soft_reuse(query, paths) + + # ============================================================== + # Step 1: Fused LLM query analysis + document routing + # When a compiled document catalog exists, the LLM sees all + # document summaries and selects the most relevant ones in the + # same call that extracts keywords (zero extra LLM cost). + # ============================================================== + catalog = artifacts.catalog + catalog_routed_files: List[str] = [] + catalog_confidence: str = "low" + + # Build tree root hints for enhanced query analysis + tree_hints = "" + if artifacts and artifacts.tree_available_paths: + tree_hints = self._build_tree_root_hints(artifacts) + + if catalog: + listing = self._build_enriched_catalog_listing(catalog) + prompt = FAST_QUERY_ANALYSIS_WITH_CATALOG.format( + user_input=query, document_listing=listing, + ) + else: + prompt = FAST_QUERY_ANALYSIS.format(user_input=query) + + # Append tree structure hints to the prompt when available + if tree_hints: + prompt = prompt + tree_hints + + # Step 1 LLM call + compile hints + tree probe run in parallel + # (GAP 3: hints前置化, GAP 1: 树导航主动化) + _step1_llm_task = self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + _compile_hints_task = self._probe_compile_hints([query], scope=_scope) + _tree_probe_task = self._probe_tree_for_fast(query, artifacts) + + _parallel_results = await asyncio.gather( + _step1_llm_task, _compile_hints_task, _tree_probe_task, + return_exceptions=True, + ) + resp = _parallel_results[0] + _early_compile_hints = _parallel_results[1] + _tree_probed_files = _parallel_results[2] + + if isinstance(resp, Exception): + await self._logger.warning(f"[FAST:Step1] LLM call failed: {resp}") + return f"Search analysis failed: {resp}", None, context + if isinstance(_early_compile_hints, Exception): + await self._logger.warning(f"[FAST:Step1] Compile hints pre-fetch failed: {_early_compile_hints}") + _early_compile_hints = CompileHints([], []) + if isinstance(_tree_probed_files, Exception): + await self._logger.warning(f"[FAST:Step1] Tree probe failed: {_tree_probed_files}") + _tree_probed_files = [] + _tree_probed_set: frozenset[str] = frozenset(_tree_probed_files) + + self.llm_usages.append(resp.usage) + if resp.usage and isinstance(resp.usage, dict): + context.add_llm_tokens( + resp.usage.get("total_tokens", 0), usage=resp.usage, + ) + + analysis = self._parse_fast_json(resp.content) + query_type = analysis.get("type", "search") + file_hints = analysis.get("file_hints", []) + + # Extract catalog-routed files from the fused response + if catalog: + selected_indices = analysis.get("selected_docs", []) + catalog_confidence = analysis.get("doc_confidence", "low") + for idx in selected_indices: + if isinstance(idx, int) and 0 <= idx < len(catalog): + fp = catalog[idx]["path"] + if Path(fp).exists(): + catalog_routed_files.append(fp) + if catalog_routed_files: + await self._logger.info( + f"[FAST:Step1] Catalog routing ({catalog_confidence}): " + f"{[Path(p).name for p in catalog_routed_files]}" + ) + + if query_type == "chat": + chat_reply = analysis.get("response", "") + if chat_reply: + await self._logger.info("[FAST:Step1] LLM classified as chat intent") + return chat_reply, None, context + return (await self._respond_chat(query, context)) + + if query_type == "summary": + await self._logger.info("[FAST:Step1] Summary intent detected — delegating to doc analysis") + # When user names a specific file, resolve it and skip dir scan + rank + summary_paths: Optional[List[str]] = None + if file_hints: + summary_paths = self._resolve_file_hints(paths, file_hints) + if summary_paths: + await self._logger.info( + f"[FAST:Summary] Resolved file hint(s) → {[Path(p).name for p in summary_paths]}" + ) + if summary_paths: + answer = await self._summarize_documents( + query, summary_paths, + top_k_files=len(summary_paths), + scan_result=None, + ) + if answer: + return answer, self._make_answer_cluster(query, answer, "FS", file_paths=summary_paths), context + # No hint or resolve failed: run dir scan (if enabled) then rank + summarize + scan_result = await self._probe_dir_scan(paths, enable=enable_dir_scan, + max_files=300) if enable_dir_scan else None + answer = await self._summarize_documents( + query, paths, + top_k_files=top_k_files, + scan_result=scan_result, + ) + if answer: + return answer, self._make_answer_cluster(query, answer, "FS", file_paths=paths), context + await self._logger.info("[FAST:Step1] Summary fallback — no documents, continuing search") + + primary = analysis.get("primary", [])[:2] + fallback = analysis.get("fallback", [])[:3] + primary_alt = analysis.get("primary_alt", [])[:2] + fallback_alt = analysis.get("fallback_alt", [])[:3] + + if primary_alt: + primary = primary + primary_alt + if fallback_alt: + fallback = fallback + fallback_alt + + # --- IDF weights from LLM --- + keyword_idfs: Dict[str, float] = analysis.get("idf", {}) + if not keyword_idfs: + all_kws = (primary or []) + (fallback or []) + keyword_idfs = {kw: max(0.5, min(1.0, len(kw) / 5.0)) for kw in all_kws} + + if not primary and not fallback: + await self._logger.warning("[FAST] No keywords extracted") + msg = f"Could not extract search terms from query: '{query}'" + return msg, None, context + + # ============================================================== + # Step 1.5: Compile-aware enrichment (P2 + P4, zero LLM calls) + # Catalog-routed files from the fused Step 1 are merged here. + # ============================================================== + all_kw_set = set(primary + fallback) + + # P2: inject soft-hit patterns as fallback keywords + if soft_hit: + for p in soft_hit.patterns: + if p not in all_kw_set: + fallback.append(p) + all_kw_set.add(p) + keyword_idfs.setdefault(p, 0.6) + + # P4: compile hints — pre-fetched (query-level) + keyword-level supplement + _kw_compile_hints = await self._probe_compile_hints(primary + fallback, scope=_scope) + compile_hints = self._merge_compile_hints(_early_compile_hints, _kw_compile_hints) + for kw in compile_hints.extra_keywords: + if kw not in all_kw_set: + fallback.append(kw) + all_kw_set.add(kw) + keyword_idfs.setdefault(kw, 0.5) + + compile_hint_files: List[str] = [] + # Catalog-routed files get highest priority + seen_hint_paths: set = set() + for fp in catalog_routed_files: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) + # Active tree probe files: second priority (GAP 1) + for fp in (_tree_probed_files or []): + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) + # Summary index BM25 files: proactive zero-LLM discovery (GAP 2) + _summary_hint_files = await self._probe_summary_index(query, artifacts, scope=_scope) + for fp in _summary_hint_files: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) + if soft_hit: + for fp in soft_hit.file_paths: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) + for fp in compile_hints.file_paths: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) + + if compile_hint_files: + await self._logger.info( + f"[FAST:Step1.5] Compile hints: {len(compile_hint_files)} files " + f"(catalog={len(catalog_routed_files)}, " + f"tree={len(_tree_probed_files) if _tree_probed_files else 0}, " + f"summary={len(_summary_hint_files)}, " + f"soft={len(soft_hit.file_paths) if soft_hit else 0}), " + f"{len(compile_hints.extra_keywords)} extra keywords" + ) + + await self._logger.info( + f"[FAST:Step1] Primary: {primary}, Fallback: {fallback}" + ) + + # ============================================================== + # Step 2: rga cascade — primary first, fallback only if needed + # When catalog routing has high confidence, catalog-routed files + # are used directly (skipping rga) to avoid noise from unrelated + # files. Otherwise rga runs first and catalog acts as fallback. + # ============================================================== + context.add_search(query) + include_patterns = list(include or []) + for hint in file_hints: + if "*" in hint or "." in hint: + include_patterns.append(hint) + + rga_kwargs = dict( + paths=paths, max_depth=max_depth, + include=include_patterns or None, exclude=exclude, + ) + + best_files: Optional[List[Dict[str, Any]]] = None + used_level = "primary" + evidence = "" + file_path: Optional[str] = None # set when best_files found + + # --- Pure tree search mode: skip rga, use tree probe results directly --- + if _PURE_TREE_SEARCH: + if _tree_probed_files: + used_level = "pure_tree" + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in _tree_probed_files[:top_k_files] + ] + print(f"SEARCH_WIKI_DEBUG [D7] _tree_probed_files={_tree_probed_files}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D8] best_files={[bf['path'] for bf in best_files]}", flush=True) + await self._logger.info( + f"[FAST:PureTree] Using {len(best_files)} tree-probed files: " + f"{[Path(p).name for p in _tree_probed_files[:top_k_files]]}" + ) + elif compile_hint_files: + # Tree probe returned nothing but compile hints have tree files + used_level = "pure_tree_hint" + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in compile_hint_files[:top_k_files] + ] + await self._logger.info( + f"[FAST:PureTree] No tree probes, falling back to " + f"{len(best_files)} compile-hint files" + ) + else: + # Graceful degradation: fall back to keyword search when no tree is available + await self._logger.info( + "[FAST:PureTree] No tree probes available, falling back to keyword search" + ) + best_files = await self._fast_find_best_file( + primary, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, **rga_kwargs, + ) + if not best_files and fallback: + best_files = await self._fast_find_best_file( + fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, **rga_kwargs, + ) + if not best_files: + return _NO_RESULTS_MESSAGE, None, context + else: + # --- Original rga-based retrieval logic --- + # High-confidence catalog routing: skip rga, use catalog directly + if catalog_routed_files and catalog_confidence == "high": + used_level = "catalog_route" + await self._logger.info( + f"[FAST:Step2] High-confidence catalog routing → " + f"{[Path(p).name for p in catalog_routed_files[:top_k_files]]}" + ) + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in catalog_routed_files[:top_k_files] + ] + + # Narrow-scope RGA: search within tree-probed files first + if not best_files and _tree_probed_set and primary: + best_files = await self._fast_find_best_file( + primary, paths=list(_tree_probed_set), + top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + ) + if best_files: + used_level = "tree_rga" + await self._logger.info( + f"[FAST:Step2] Narrow-scope tree+rga hit → " + f"{[Path(f['path']).name for f in best_files]}" + ) + + # Full-scope RGA with tree probe boost + if not best_files and primary: + best_files = await self._fast_find_best_file( + primary, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + tree_probed_paths=_tree_probed_set or None, + **rga_kwargs, + ) + + if not best_files and fallback: + used_level = "fallback" + await self._logger.info( + "[FAST:Step2] Primary miss, trying fine-grained fallback" + ) + best_files = await self._fast_find_best_file( + fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + tree_probed_paths=_tree_probed_set or None, + **rga_kwargs, + ) + + # --- Fallback: compile-hint files when rga misses (catalog + P2 + P4) --- + if not best_files and compile_hint_files: + used_level = "compile_hint" + await self._logger.info( + f"[FAST:Step2] rga miss — using {len(compile_hint_files)} compile-hint files" + ) + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in compile_hint_files[:top_k_files] + ] + + # --- Fallback: use dir_scan only when rga misses and dir scan is enabled --- + if not best_files and enable_dir_scan: + scan_result = await self._probe_dir_scan(paths, enable=True, max_files=300) + if scan_result is not None: + await self._logger.info("[FAST:Step2] rga miss — falling back to dir_scan ranking") + ranked_paths = await self._rank_dir_scan_candidates( + query, scan_result, top_k=10, include_medium=True, + ) + if ranked_paths: + used_level = "dir_scan" + best_files = [{"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} for p in ranked_paths[:top_k_files]] + + if not best_files: + if llm_fallback: + await self._logger.info( + "[FAST:Step2] No files found, llm_fallback=True \u2192 skip to LLM summary" + ) + evidence = self._LLM_FALLBACK_EVIDENCE + else: + await self._logger.warning( + f"[FAST:Step2] No matching files found in paths: {paths}. " + ) + return _NO_RESULTS_MESSAGE, None, context + + if best_files: + file_path = best_files[0]["path"] + match_objects = best_files[0].get("matches", []) + wiki_info = "" + if best_files[0].get("wiki_relevance") is not None: + wiki_info = f", wiki={best_files[0]['wiki_relevance']:.1f}" + await self._logger.info( + f"[FAST:Step2] Best file ({used_level}): {Path(file_path).name} " + f"({best_files[0].get('total_matches', 0)} hits, " + f"score={best_files[0].get('weighted_score', 0):.2f}{wiki_info})" + ) + + # ============================================================== + # Step 2.5 + Step 3: Tree navigation (1 LLM call) runs in + # parallel with rga evidence sampling (0 LLM). The merged + # result is higher quality than either alone. + # Tree-guided sampling is integrated into _rga_evidence() for + # secondary files; the primary file gets a dedicated parallel + # tree_task to avoid blocking rga. + # ============================================================== + + # Track files already receiving parallel tree navigation to + # avoid duplicate LLM calls inside _rga_evidence(). + tree_nav_done: Set[str] = set() + tree_nav_target = best_files[0]["path"] + + print(f"SEARCH_WIKI_DEBUG [D9] tree_nav_target={tree_nav_target}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D10] tree_nav_match={tree_nav_target in (artifacts.tree_available_paths if artifacts else set())}", flush=True) + if artifacts and tree_nav_target not in artifacts.tree_available_paths: + print(f"SEARCH_WIKI_DEBUG [D11] MISMATCH! tree_available_paths={artifacts.tree_available_paths}", flush=True) + + if artifacts and tree_nav_target in artifacts.tree_available_paths: + tree_task = self._navigate_tree_for_evidence( + tree_nav_target, query, + max_results=self._TREE_NAV_MAX_RESULTS, + match_objects=best_files[0].get("matches"), + ) + tree_nav_done.add(tree_nav_target) + else: + tree_task = self._async_noop(None) + + async def _rga_evidence() -> str: + """Collect evidence from best_files: tree-guided when available, rga fallback.""" + parts: List[str] = [] + chars = 0 + for bf in best_files: + if chars >= self._FAST_MAX_EVIDENCE_CHARS: + break + fp = bf["path"] + fn = Path(fp).name + ext = Path(fp).suffix.lower() + ev = None + + print(f"SEARCH_WIKI_DEBUG [D12] _rga_evidence: fp={fp}", flush=True) + + # 0. Excel digest priority (pre-compiled evidence) + if artifacts and artifacts.manifest_map: + manifest_entry = artifacts.manifest_map.get(fp) + if manifest_entry and getattr(manifest_entry, 'has_xlsx_digest', False): + digest_path = ( + self.work_path / ".cache" / "compile" / "xlsx_digests" + / f"{manifest_entry.file_hash}.txt" + ) + if digest_path.exists(): + try: + digest_content = digest_path.read_text(encoding="utf-8") + if digest_content.strip(): + ev = f"[{fn} - Pre-compiled Evidence]\n{digest_content}" + except Exception: + pass + + # 0.5 Table digest priority (pre-compiled PDF table evidence) + _all_tables = None + if ev is None and artifacts: + # Primary: manifest-based lookup + if artifacts.manifest_map: + _me = artifacts.manifest_map.get(fp) + if _me and getattr(_me, 'has_table_digest', False): + _all_tables = self._load_table_digest( + self.work_path, _me.file_hash, + ) + + # Fallback: direct hash-based lookup when manifest misses + if not _all_tables: + try: + from sirchmunk.utils.file_utils import get_fast_hash + _file_hash = get_fast_hash(fp) + if _file_hash: + _all_tables = self._load_table_digest( + self.work_path, _file_hash, + ) + except Exception: + pass + + print(f"SEARCH_WIKI_DEBUG [D13] table_digest: manifest_lookup={'found' if artifacts.manifest_map and artifacts.manifest_map.get(fp) else 'miss'}, has_table_digest={getattr(artifacts.manifest_map.get(fp), 'has_table_digest', False) if artifacts.manifest_map else 'N/A'}, hash_fallback={'tried' if not _all_tables else 'skipped'}, tables_count={len(_all_tables) if _all_tables else 0}", flush=True) + + if _all_tables: + _td_budget = ( + self._TABLE_EVIDENCE_NAV_OVERLAP_CHARS + if fp in tree_nav_done + else self._TABLE_EVIDENCE_DEFAULT_CHARS + ) + _table_ev = self._format_table_evidence( + _all_tables, + max_chars=_td_budget, + query=query, + ) + if _table_ev: + ev = f"[{fn} - Table Evidence]\n{_table_ev}" + + # 1. Tree-guided sampling for tree-indexed files + # (skipped when a parallel tree_task already covers this file) + _tree_cond = artifacts and fp in artifacts.tree_available_paths and fp not in tree_nav_done + print(f"SEARCH_WIKI_DEBUG [D14] tree_sample: cond={_tree_cond}, in_tree_paths={fp in (artifacts.tree_available_paths if artifacts else set())}, in_nav_done={fp in tree_nav_done}", flush=True) + if ( + artifacts + and fp in artifacts.tree_available_paths + and fp not in tree_nav_done + ): + try: + tree_ev_inner = await self._tree_guided_sample( + fp, query, + match_objects=bf.get("matches", []), + max_chars=self._FAST_MAX_EVIDENCE_CHARS - chars, + artifacts=artifacts, + ) + if tree_ev_inner: + if ev: + ev = ev + "\n\n" + tree_ev_inner + else: + ev = tree_ev_inner + await self._logger.info( + f"[FAST:Step3] Tree-guided sample for {fn} " + f"({len(tree_ev_inner)} chars)" + ) + except Exception: + pass + + # 2. Small file: read entirely (only if tree didn't provide evidence) + if ev is None and ext in self._FAST_TEXT_EXTENSIONS: + try: + sz = Path(fp).stat().st_size + if sz < self._FAST_SMALL_FILE_THRESHOLD: + full = Path(fp).read_text(errors="replace") + if len(full) < self._FAST_SMALL_FILE_THRESHOLD: + ev = f"[{fn}]\n{full}" + except Exception: + pass + + # 3. Fallback: rga sampling (existing logic) + if ev is None: + ev = await self._fast_sample_evidence(fp, bf.get("matches", [])) + + if ev: + remaining = self._FAST_MAX_EVIDENCE_CHARS - chars + parts.append(ev[:remaining]) + chars += len(parts[-1]) + context.mark_file_read(fp) + + _ev_source = "none" + if ev: + if "Table Evidence" in ev: _ev_source = "table_digest" + elif "Pre-compiled" in ev: _ev_source = "excel_digest" + elif "TreeSample" in str(ev)[:50] or "TreeNav" in str(ev)[:50]: _ev_source = "tree" + else: _ev_source = "rga_or_other" + print(f"SEARCH_WIKI_DEBUG [D15] ev_source={_ev_source}, ev_len={len(ev) if ev else 0}", flush=True) + return "\n\n---\n\n".join(parts) + + # Launch tree navigation alongside rga evidence collection. + rga_ev, tree_ev = await asyncio.gather(_rga_evidence(), tree_task) + + # Merge: tree evidence first (highest quality), then rga + if tree_ev and rga_ev: + rga_ev = self._deduplicate_table_sections(tree_ev, rga_ev) + evidence_parts_final: List[str] = [] + if tree_ev: + evidence_parts_final.append(tree_ev) + if rga_ev: + evidence_parts_final.append(rga_ev) + evidence = "\n\n---\n\n".join(evidence_parts_final) + + print(f"SEARCH_WIKI_DEBUG [D16] tree_ev: {'yes' if tree_ev else 'no'}, len={len(tree_ev) if tree_ev else 0}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D17] rga_ev: {'yes' if rga_ev else 'no'}, len={len(rga_ev) if rga_ev else 0}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D18] final_evidence_len={len(evidence)}", flush=True) + + if not evidence or len(evidence.strip()) < 20: + if llm_fallback: + await self._logger.info( + "[FAST:Step3] No usable evidence, llm_fallback=True → LLM summary" + ) + evidence = self._LLM_FALLBACK_EVIDENCE + else: + await self._logger.warning("[FAST:Step3] No usable evidence extracted") + return _NO_RESULTS_MESSAGE, None, context + + tree_available = file_path in artifacts.tree_available_paths if artifacts else False + await self._logger.info( + f"[FAST:Step3] Evidence: {len(evidence)} chars " + f"(tree={'yes' if tree_ev else 'no'}, rga={'yes' if rga_ev else 'no'}, " + f"tree_indexed={'yes' if tree_available else 'no'})" + ) + + keywords_used = primary if used_level == "primary" else fallback + + # ============================================================== + # Step 4: LLM answer from focused evidence (single call) + # Wiki-enhanced: inject document context when catalog available. + # ============================================================== + doc_context = self._build_answer_context(file_path, artifacts) if best_files else None + if doc_context: + from sirchmunk.llm.prompts import ROI_RESULT_SUMMARY_WITH_CONTEXT + answer_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( + user_input=query, + text_content=evidence, + document_context=doc_context, + ) + await self._logger.info( + f"[FAST:Step4] Wiki-enhanced answer generation with catalog context" + ) + else: + answer_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, + text_content=evidence, + ) + answer_resp = await self.llm.achat( + messages=[{"role": "user", "content": answer_prompt}], + stream=True, + ) + self.llm_usages.append(answer_resp.usage) + if answer_resp.usage and isinstance(answer_resp.usage, dict): + context.add_llm_tokens( + answer_resp.usage.get("total_tokens", 0), usage=answer_resp.usage, + ) + + answer, should_save, should_answer = self._parse_summary_response( + answer_resp.content or "" + ) + + # --- Multi-factor evidence acceptance (P2+P3+P4) --- + accepted, accept_reason = self._evaluate_evidence_acceptance( + query, evidence, should_answer, + ) + await self._logger.info( + f"[FAST:Step4] Evidence acceptance: {accepted} ({accept_reason})" + ) + + # ============================================================== + # Step 5: Self-correction retry (conditional, ≤1 extra LLM call) + # When the answer gate rejects the first attempt, try alternative + # evidence sources before giving up. + # ============================================================== + if not accepted: + retry_evidence = await self._fast_self_correct( + query, best_files, catalog_routed_files, context, + ) + if retry_evidence: + await self._logger.info( + f"[FAST:Step5] Retrying with {len(retry_evidence)} chars of alternative evidence" + ) + retry_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, text_content=retry_evidence, + ) + retry_resp = await self.llm.achat( + messages=[{"role": "user", "content": retry_prompt}], + stream=True, + ) + self.llm_usages.append(retry_resp.usage) + if retry_resp.usage and isinstance(retry_resp.usage, dict): + context.add_llm_tokens( + retry_resp.usage.get("total_tokens", 0), usage=retry_resp.usage, + ) + answer, should_save, retry_should_answer = self._parse_summary_response( + retry_resp.content or "" + ) + retry_accepted, retry_reason = self._evaluate_evidence_acceptance( + query, retry_evidence, retry_should_answer, + ) + await self._logger.info( + f"[FAST:Step5] Retry evidence acceptance: {retry_accepted} ({retry_reason})" + ) + if retry_accepted: + accepted = True + + if not accepted: + if llm_fallback: + await self._logger.info( + "[FAST:Step5] Retry also rejected, llm_fallback=True → LLM fallback" + ) + answer, should_save = await self._summarise_fast_fallback(query, context) + else: + await self._logger.warning( + "[FAST:Step5] Evidence rejected after retry, llm_fallback=False " + "→ returning no results" + ) + return _NO_RESULTS_MESSAGE, None, context + + if not should_save: + await self._logger.info("[FAST] Quality gate: low-quality answer, skipping cluster save") + await self._logger.success("[FAST] Search complete (no persist)") + return answer, None, context + + cluster = self._build_fast_cluster( + query, answer, file_path or "", evidence, keywords_used, + ) + self._add_query_to_cluster(cluster, query) + try: + await self._save_cluster_with_embedding(cluster) + except Exception as exc: + _loguru_logger.warning( + f"[FAST] Failed to save cluster with embedding: {exc}" + ) + + await self._logger.success("[FAST] Search complete") + return answer, cluster, context + + # ---- FAST helpers ---- + + @staticmethod + def _count_keyword_tf_per_file(raw_results: List[Dict[str, Any]]) -> Dict[str, int]: + """Count matches per file from rga JSON output.""" + counts: Dict[str, int] = {} + current_path: Optional[str] = None + for item in raw_results: + item_type = item.get("type") + if item_type == "begin": + current_path = item.get("data", {}).get("path", {}).get("text") + elif item_type == "match" and current_path is not None: + counts[current_path] = counts.get(current_path, 0) + 1 + elif item_type == "end": + current_path = None + return counts + + @staticmethod + def _dedup_merged_files( + merged: List[Dict[str, Any]], + per_file_kw_tf: Dict[str, Dict[str, int]], + match_limit: int = 20, + ) -> List[Dict[str, Any]]: + """Deduplicate merged file entries by path, combining matches from + multiple keyword searches into a single entry per file. + + When the same file appears in multiple rga begin/end groups (one per + keyword search), this merges them so downstream scoring and evidence + extraction operate on a single, complete representation. + + Args: + merged: File entries from GrepRetriever.merge_results(), may + contain duplicates. + per_file_kw_tf: Pre-computed per-file keyword TF counts (not + modified, used only for reference). + match_limit: Maximum matches to keep per file after merging. + + Returns: + Deduplicated list with one entry per unique file path. + """ + if not merged: + return merged + + seen: Dict[str, int] = {} # path -> index in deduped + deduped: List[Dict[str, Any]] = [] + + for entry in merged: + fpath = entry["path"] + if fpath in seen: + # Merge into existing entry + idx = seen[fpath] + existing = deduped[idx] + existing["matches"].extend(entry.get("matches", [])) + existing["lines"].extend(entry.get("lines", [])) + existing["total_matches"] += entry.get("total_matches", 0) + else: + # New file — clone to avoid mutating original + seen[fpath] = len(deduped) + deduped.append({ + "path": fpath, + "matches": list(entry.get("matches", [])), + "lines": list(entry.get("lines", [])), + "total_matches": entry.get("total_matches", 0), + "total_score": entry.get("total_score", 0.0), + }) + + # Trim matches to limit per file + for entry in deduped: + if len(entry["matches"]) > match_limit: + # Sort by score descending, keep top + entry["matches"].sort( + key=lambda x: x.get("score", 0.0), reverse=True + ) + entry["matches"] = entry["matches"][:match_limit] + + return deduped + + @staticmethod + def _prune_by_score( + candidates: List[Dict[str, Any]], + top_k: int = 3, + relative_ratio: float = 0.30, + gap_ratio: float = 0.50, + min_count: int = 1, + ) -> List[Dict[str, Any]]: + """Dynamically prune ranked file candidates by score distribution. + + Applies a three-stage filter to remove clearly irrelevant files: + + 1. **Relative threshold**: Discard files scoring below + ``max_score * relative_ratio`` (default 30%). + 2. **Gap detection**: Scan adjacently ranked files; when the score + drop from one to the next exceeds ``prev_score * gap_ratio`` + (default 50%), truncate the list at that point. + 3. **Minimum guarantee**: Ensure at least ``min_count`` files + survive (default 1). + + Finally the result is capped at ``top_k``. + + Args: + candidates: File dicts sorted by ``weighted_score`` descending. + top_k: Maximum number of files to return. + relative_ratio: Fraction of the top score used as a floor. + gap_ratio: Maximum tolerated relative drop between adjacent + candidates. + min_count: Minimum number of candidates to keep regardless of + score. + + Returns: + Pruned list of candidates (length in [min_count, top_k]). + """ + if not candidates: + return [] + + max_score = candidates[0].get("weighted_score", 0.0) + + # Step 1: Relative threshold filter + threshold = max_score * relative_ratio + filtered = [f for f in candidates if f.get("weighted_score", 0.0) >= threshold] + if not filtered: + filtered = candidates[:min_count] + + # Step 2: Gap detection truncation + result = [filtered[0]] + for i in range(1, len(filtered)): + prev_score = filtered[i - 1].get("weighted_score", 0.0) + curr_score = filtered[i].get("weighted_score", 0.0) + if prev_score > 0 and (prev_score - curr_score) > prev_score * gap_ratio: + break + result.append(filtered[i]) + + # Step 3: Minimum guarantee + if len(result) < min_count and len(filtered) >= min_count: + result = filtered[:min_count] + + # Cap at top_k + return result[:top_k] + + @staticmethod + def _compute_wiki_relevance( + file_path: str, + query: str, + keywords: List[str], + catalog_map: Dict[str, Dict[str, str]], + tree_available_paths: Set[str], + ) -> float: + """Compute wiki-based relevance score for a candidate file (0-10 scale). + + Uses three sub-scores derived from compile artifacts: + + 1. **Catalog summary overlap** (0-``_WIKI_CATALOG_KEYWORD_OVERLAP_MAX``): + proportion of query keywords that appear in the catalog entry's + summary. When *keywords* is empty, falls back to whole-query + substring matching against the summary to avoid returning 0 for + valid queries. + 2. **Tree availability bonus** (0-``_WIKI_TREE_AVAILABILITY_BONUS``): + a file with a compiled tree index likely has rich structure. + 3. **Catalog presence bonus** (0-``_WIKI_CATALOG_PRESENCE_FULL``): + files important enough to be in the catalog get a baseline boost. + + All scoring is pure text matching — no LLM, no embedding. + + Args: + file_path: Absolute path of the candidate file. + query: Original user query. + keywords: Extracted search keywords from FAST Step 1. + catalog_map: ``{path: catalog_entry}`` from CompileArtifacts. + tree_available_paths: Set of file paths with cached tree indices. + + Returns: + Float in [0, 10] representing wiki-derived relevance. + """ + cls = AgenticSearch # access class constants from static method + score = 0.0 + + entry = catalog_map.get(file_path) + + # Sub-score 1: Catalog summary keyword overlap + if entry: + summary_lower = (entry.get("summary", "") + " " + entry.get("name", "")).lower() + query_lower = query.lower() + matches = 0 + total = 0 + summary_tokens = cls._tokenize_for_matching(summary_lower) + for kw in keywords: + if kw: + total += 1 + kw_low = kw.lower() + if kw_low in summary_tokens: + matches += 1 # Full token match + elif kw_low in summary_lower: + matches += 0.5 # Substring-only match (lower confidence) + # Also check whole query as a substring + if len(query_lower) >= 2 and query_lower in summary_lower: + matches += 1 + total += 1 + # When keywords list is empty but query is non-empty, fall back to + # character-level overlap so the sub-score is not silently 0. + if total == 0 and query_lower: + # Simple overlap: count how many query chars appear in summary + overlap = sum(1 for ch in query_lower if ch in summary_lower) + ratio = overlap / max(len(query_lower), 1) + score += ratio * cls._WIKI_CATALOG_KEYWORD_OVERLAP_MAX + elif total > 0: + score += (matches / total) * cls._WIKI_CATALOG_KEYWORD_OVERLAP_MAX + + # Sub-score 2: Tree availability bonus + if file_path in tree_available_paths: + score += cls._WIKI_TREE_AVAILABILITY_BONUS + + # Sub-score 3: Catalog presence bonus + if entry: + summary_len = len(entry.get("summary", "")) + if summary_len > 100: + score += cls._WIKI_CATALOG_PRESENCE_FULL + elif summary_len > 30: + score += cls._WIKI_CATALOG_PRESENCE_MEDIUM + elif summary_len > 0: + score += cls._WIKI_CATALOG_PRESENCE_MINIMAL + + return min(score, cls._WIKI_MAX_SCORE) + + async def _fast_find_best_file( + self, + keywords: List[str], + paths: List[str], + max_depth: Optional[int] = 5, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + top_k: int = 1, + keyword_idfs: Optional[Dict[str, float]] = None, + query: str = "", + artifacts: Optional["CompileArtifacts"] = None, + tree_probed_paths: Optional[Set[str]] = None, + ) -> Optional[List[Dict[str, Any]]]: + """Search per keyword via rga and return the top-k best-matching files + ranked by IDF-weighted log-TF scoring, optionally enhanced with + wiki-derived relevance from compile artifacts. + + When *tree_probed_paths* is provided, files that were selected by + LLM-driven tree probing receive a ranking boost, ensuring the tree + probe's high-quality signal influences the final file ordering. + + Args: + keywords: Search keywords from FAST Step 1. + paths: Search paths. + max_depth: Maximum directory depth for rga. + include: Glob patterns to include. + exclude: Glob patterns to exclude. + top_k: Number of top files to return. + keyword_idfs: Pre-computed IDF values for keywords. + query: Original user query (used for wiki relevance scoring). + artifacts: Compile artifacts for adaptive wiki-enhanced ranking. + tree_probed_paths: File paths selected by tree probing (receive boost). + + Returns: + List of merged file dicts (path, matches, lines, total_matches, weighted_score) or None. + """ + all_raw: List[Dict[str, Any]] = [] + per_file_kw_tf: Dict[str, Dict[str, int]] = {} # {file_path: {keyword: count}} + + for kw in keywords: + try: + results = await self.grep_retriever.retrieve( + terms=kw, path=paths, literal=True, regex=False, + max_depth=max_depth, include=include, exclude=exclude, + timeout=30.0, + ) + if results: + all_raw.extend(results) + # Track per-file TF for this keyword + kw_counts = self._count_keyword_tf_per_file(results) + for fpath, count in kw_counts.items(): + per_file_kw_tf.setdefault(fpath, {})[kw] = count + except Exception as exc: + await self._logger.warning( + f"[FAST] rga literal search failed for '{kw}': {exc}" + ) + + # Fallback: escaped-regex OR (handles adapters that only work in regex mode) + if not all_raw and keywords: + try: + escaped = [re.escape(kw) for kw in keywords] + pattern = "|".join(escaped) + results = await self.grep_retriever.retrieve( + terms=pattern, path=paths, literal=False, regex=True, + max_depth=max_depth, include=include, exclude=exclude, + timeout=30.0, + ) + if results: + all_raw.extend(results) + # For regex OR fallback, attribute matches to individual keywords + # by checking which keywords appear in each match line + # (simplified: count total matches per file, distribute proportionally) + regex_counts = self._count_keyword_tf_per_file(results) + for fpath, count in regex_counts.items(): + # Attribute to all keywords equally (approximation for OR regex) + per_kw_share = max(1, count // len(keywords)) if keywords else count + for kw in keywords: + existing = per_file_kw_tf.get(fpath, {}).get(kw, 0) + if existing == 0: # Only fill if not already set by literal search + per_file_kw_tf.setdefault(fpath, {})[kw] = per_kw_share + except Exception as exc: + await self._logger.warning( + f"[FAST] rga regex search failed: {exc}" + ) + + # Fallback: filename search + if not all_raw: + try: + fn_results = await self.grep_retriever.retrieve_by_filename( + patterns=[f".*{re.escape(kw)}.*" for kw in keywords], + path=paths, case_sensitive=False, max_depth=max_depth, + timeout=30.0, + ) + if fn_results: + return [{"path": fn_results[0]["path"], "matches": [], "lines": [], "total_matches": 0, "weighted_score": 0.0}] + except Exception as exc: + await self._logger.warning( + f"[FAST] filename search failed: {exc}" + ) + + # Layer 4: Embedding + BM25 hybrid fallback + # Triggered ONLY when layers 1-3 all return empty results + if (not all_raw + and self._ENABLE_EMBEDDING_FALLBACK + and artifacts is not None + and artifacts.summary_index is not None): + try: + query_emb = None + query_tokens: List[str] = [] + + # Compute query embedding (if embedding client available) + if (self.embedding_client + and self.embedding_client.is_ready() + and artifacts.summary_index.has_embeddings): + query_emb = (await self.embedding_client.embed([query]))[0] + + # Tokenize query for BM25 + from sirchmunk.utils.tokenizer_util import TokenizerUtil + _tokenizer = TokenizerUtil() + query_tokens = _tokenizer.segment(query) + + if query_emb is not None or query_tokens: + results = artifacts.summary_index.search( + query_embedding=query_emb, + query_tokens=query_tokens, + top_k=top_k or 3, + ) + + for file_path, score in results: + if Path(file_path).exists(): + all_raw.append({ + "path": file_path, + "matches": [], + "weighted_score": score * self._WIKI_MAX_SCORE, + }) + + if all_raw: + await self._logger.info( + f"[FAST] Embedding+BM25 fallback found {len(all_raw)} candidates" + ) + except Exception as exc: + await self._logger.warning( + f"[FAST] Embedding+BM25 fallback failed: {exc}" + ) + + if not all_raw: + return None + + merged = GrepRetriever.merge_results(all_raw, limit=20) + if not merged: + return None + + # Deduplicate file entries from multi-keyword searches + merged = self._dedup_merged_files(merged, per_file_kw_tf) + + # --- IDF × (1 + log TF) weighted scoring --- + _idfs = keyword_idfs or {} + for f in merged: + fpath = f["path"] + kw_tf = per_file_kw_tf.get(fpath, {}) + score = 0.0 + for kw in keywords: + tf = kw_tf.get(kw, 0) + if tf > 0: + idf = _idfs.get(kw, max(0.5, min(1.0, len(kw) / 5.0))) + score += idf * (1.0 + math.log(tf)) + f["weighted_score"] = score + + # --- Wiki-enhanced hybrid scoring (adaptive: only when artifacts exist) --- + if artifacts and artifacts.catalog_map: + # Normalize TF-IDF scores to [0, 10] to align with Wiki score range + max_tf_idf = max((f["weighted_score"] for f in merged), default=1.0) + if max_tf_idf <= 0: + max_tf_idf = 1.0 + for f in merged: + wiki_score = self._compute_wiki_relevance( + f["path"], query, keywords, + artifacts.catalog_map, artifacts.tree_available_paths, + ) + f["wiki_relevance"] = wiki_score + # Normalize TF-IDF to [0, 10] before blending + tf_idf_norm = (f["weighted_score"] / max_tf_idf) * self._WIKI_MAX_SCORE + f["weighted_score"] = ( + self._WIKI_BLEND_ALPHA * tf_idf_norm + + (1 - self._WIKI_BLEND_ALPHA) * wiki_score + ) + + if tree_probed_paths: + for f in merged: + if f["path"] in tree_probed_paths: + f["weighted_score"] += self._TREE_PROBE_RANKING_BOOST + + merged.sort(key=lambda f: f["weighted_score"], reverse=True) + pruned = self._prune_by_score(merged, top_k=top_k) + + return pruned if pruned else None + + async def _fast_sample_evidence( + self, + file_path: str, + match_objects: List[Dict[str, Any]], + ) -> str: + """Build focused evidence from grep hits: context windows for text + files, raw match snippets for binary formats. + + Args: + file_path: Absolute path to the best file. + match_objects: Match event dicts from ``merge_results``. + + Returns: + Formatted evidence string. + """ + fname = Path(file_path).name + ext = Path(file_path).suffix.lower() + + # Extract match line numbers + hit_lines: List[int] = [] + for m in match_objects: + ln = m.get("data", {}).get("line_number") + if isinstance(ln, int): + hit_lines.append(ln) + + # Diagnostic logging when falling back to snippet mode + if not hit_lines and match_objects: + await self._logger.info( + f"[FAST] No line_number in {len(match_objects)} match(es) for {fname}, " + f"falling back to snippet mode" + ) + + # --- Text files: read context windows around hits --- + if ext in self._FAST_TEXT_EXTENSIONS and hit_lines: + # Expand context window for sparse hits + window = self._FAST_CONTEXT_WINDOW + if len(hit_lines) <= 2: + window = max(window, 100) # ±100 lines for 1-2 hits + evidence = self._read_context_windows( + file_path, hit_lines, + window=window, + max_chars=self._FAST_MAX_EVIDENCE_CHARS, + ) + if evidence: + full_evidence = f"[{fname}]\n{evidence}" + if len(full_evidence) < 100: + await self._logger.info( + f"[FAST] Context window evidence too thin ({len(full_evidence)} chars) for {fname}, " + f"attempting file head extraction" + ) + head_evidence = await self._fast_read_file_head(file_path) + if head_evidence and len(head_evidence) > len(full_evidence): + return head_evidence + return full_evidence + + # --- Non-text files or no line numbers: use grep snippets --- + snippets: List[str] = [] + total = 0 + for m in match_objects: + line_text = m.get("data", {}).get("lines", {}).get("text", "").rstrip() + if not line_text: + continue + snippets.append(line_text) + total += len(line_text) + if total >= self._FAST_MAX_EVIDENCE_CHARS: + break + + if snippets: + snippet_evidence = f"[{fname}]\n" + "\n".join(snippets) + # If snippet evidence is too thin, try file head for richer context + if len(snippet_evidence) < 100: + await self._logger.info( + f"[FAST] Evidence too thin ({len(snippet_evidence)} chars) for {fname}, " + f"attempting file head extraction" + ) + head_evidence = await self._fast_read_file_head(file_path) + if head_evidence and len(head_evidence) > len(snippet_evidence): + return head_evidence + return snippet_evidence + + # Last resort: try reading file head + return await self._fast_read_file_head(file_path) + + @staticmethod + def _read_context_windows( + file_path: str, + hit_lines: List[int], + window: int = 30, + max_chars: int = 15_000, + ) -> Optional[str]: + """Read context windows around *hit_lines* from a text file. + + Merges overlapping windows to avoid duplication. Stops when + *max_chars* is reached. + """ + # Merge overlapping intervals + intervals = sorted(set( + (max(1, ln - window), ln + window) for ln in hit_lines + )) + merged: List[tuple] = [intervals[0]] + for start, end in intervals[1:]: + if start <= merged[-1][1] + 1: + merged[-1] = (merged[-1][0], max(merged[-1][1], end)) + else: + merged.append((start, end)) + + # Read file and extract windows + try: + with open(file_path, "r", encoding="utf-8", errors="replace") as f: + all_lines = f.readlines() + except Exception: + return None + + parts: List[str] = [] + total = 0 + for start, end in merged: + s = max(0, start - 1) # 0-indexed + e = min(len(all_lines), end) + chunk = "".join(all_lines[s:e]) + if total + len(chunk) > max_chars: + remaining = max_chars - total + if remaining > 200: + chunk = chunk[:remaining] + "\n[...truncated...]" + parts.append(chunk) + break + parts.append(chunk) + total += len(chunk) + + if not parts: + return None + + # Join windows with separator when there are gaps + return "\n[...]\n".join(parts) + + @classmethod + async def _fast_read_file_head( + cls, file_path: str, max_chars: int = 8_000, + ) -> str: + """Read the head of a file as last-resort evidence.""" + try: + p = Path(file_path) + if p.suffix.lower() in cls._FAST_TEXT_EXTENSIONS: + text = p.read_text(encoding="utf-8", errors="replace") + else: + from sirchmunk.utils.file_utils import fast_extract + result = await fast_extract(file_path) + text = result.content if result and result.content else "" + if text: + return f"[{p.name}]\n{text[:max_chars]}" + except Exception: + pass + return "" + + def _load_document_catalog(self) -> Optional[List[Dict[str, str]]]: + """Load the compiled document catalog for fused query+route prompt. + + Returns None when compile has not been run or catalog is missing. + """ + catalog_path = self.work_path / ".cache" / "compile" / "document_catalog.json" + if not catalog_path.exists(): + return None + try: + entries = json.loads(catalog_path.read_text(encoding="utf-8")) + if isinstance(entries, list) and entries: + return entries + except Exception: + pass + return None + + def _detect_compile_artifacts( + self, + search_paths: Optional[List[str]] = None, + ) -> CompileArtifacts: + """One-shot probe of all compile artifacts for adaptive FAST activation. + + Reads the document catalog and scans the tree cache directory to + determine which compile products are available. Called once at the + start of ``_search_fast()``; the result is passed to downstream + helpers so they can enable enhanced logic only when artifacts exist. + + When *search_paths* is provided, all returned artifacts are filtered + to only include entries whose file paths fall within the search scope. + This ensures downstream consumers (catalog routing, tree probing, + summary index) never see documents outside the requested scope. + + Cost: one JSON read (catalog) + one directory listing (tree cache). + Tree path results are cached in ``_tree_paths_cache`` so subsequent + calls within the same instance avoid re-parsing every JSON file. + Returns a ``CompileArtifacts`` with ``None``/empty fields when + compile has not been run. + """ + scope = _PathScope(search_paths) + + catalog = self._load_document_catalog() + catalog_map: Dict[str, Dict[str, str]] = {} + if catalog: + for entry in catalog: + p = entry.get("path", "") + if p: + catalog_map[p] = entry + + # Load manifest for rich metadata (size, has_tree, cluster_ids) + manifest_map: Dict[str, Any] = {} + manifest_path = self.work_path / ".cache" / "compile" / "manifest.json" + if manifest_path.exists(): + try: + from sirchmunk.learnings.compiler import CompileManifest + manifest = CompileManifest.from_json( + manifest_path.read_text(encoding="utf-8") + ) + manifest_map = manifest.files # {file_path: FileManifestEntry} + except Exception: + pass + + indexer = self._get_tree_indexer() + # Use cached tree paths when available to avoid re-parsing all JSONs + tree_paths: Set[str] = getattr(self, "_tree_paths_cache", None) or set() + if not tree_paths: + # Prefer manifest-based detection (fast, O(1) per file) + if manifest_map: + tree_paths = {fp for fp, entry in manifest_map.items() if entry.has_tree} + # Always try directory fallback if manifest-based detection found nothing + if not tree_paths and indexer is not None: + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if tree_cache.exists(): + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + for tf in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: + try: + tree = DocumentTree.from_json( + tf.read_text(encoding="utf-8") + ) + if tree.file_path: + tree_paths.add(tree.file_path) + except Exception: + pass + except Exception: + pass + # Cache for future calls within this instance + self._tree_paths_cache = tree_paths + + # Load summary index for embedding fallback (optional) + summary_index = None + summary_index_path = self.work_path / ".cache" / "compile" / "summary_index.json" + if summary_index_path.exists(): + try: + from sirchmunk.learnings.summary_index import CompileSummaryIndex + summary_index = CompileSummaryIndex.load(summary_index_path) + except Exception: + pass + + # --- Apply search-path scope filtering --- + if not scope.is_empty: + if catalog: + catalog = [e for e in catalog if scope.contains(e.get("path", ""))] + catalog_map = {p: e for p, e in catalog_map.items() if scope.contains(p)} + tree_paths = {p for p in tree_paths if scope.contains(p)} + manifest_map = {p: e for p, e in manifest_map.items() if scope.contains(p)} + + print(f"SEARCH_WIKI_DEBUG [D1] manifest_map: {len(manifest_map)} entries, keys={list(manifest_map.keys())[:3]}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D2] tree_available_paths: {tree_paths}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D3] manifest_fallback_executed: {manifest_map and not tree_paths}", flush=True) + return CompileArtifacts( + catalog=catalog, + catalog_map=catalog_map, + tree_indexer=indexer, + tree_available_paths=tree_paths, + manifest_map=manifest_map, + summary_index=summary_index, + ) + + def _build_tree_root_hints(self, artifacts: CompileArtifacts) -> str: + """Build tree root summary hints for FAST Step 1 query analysis. + + Loads root summaries from cached trees and formats them as context + for the LLM to understand document-level structure. + + Args: + artifacts: Compile artifact context with tree metadata. + + Returns: + Formatted hint string, or empty string when no trees are available. + """ + if not artifacts.tree_available_paths: + return "" + indexer = artifacts.tree_indexer + if indexer is None: + return "" + hints: List[str] = [] + for i, fp in enumerate(sorted(artifacts.tree_available_paths)): + if i >= self._TREE_ROOT_HINTS_MAX_FILES: + break + tree = indexer.load_tree(fp) + if tree and tree.root and tree.root.summary: + name = Path(fp).name + hints.append(f"[{i}] {name}: {tree.root.summary[:self._TREE_ROOT_HINT_TRUNCATE]}") + if not hints: + return "" + return "\nDocument structure hints:\n" + "\n".join(hints) + "\n" + + @staticmethod + def _tokenize_for_matching(text: str) -> Set[str]: + """Tokenize text into meaningful units for keyword matching. + + Splits on whitespace and CJK/Latin punctuation boundaries, then + generates 2-3 char n-grams for CJK-heavy tokens to handle + unsegmented Chinese text. Returns a set of lowercased tokens. + """ + import re + tokens: Set[str] = set() + raw = re.split(r'[\s,;.!?,;。!?::、\u201c\u201d\u2018\u2019()()\[\]{}<>《》\-/]+', text.lower()) + for t in raw: + t = t.strip() + if not t: + continue + tokens.add(t) + if len(t) >= 2 and any('\u4e00' <= c <= '\u9fff' for c in t): + for n in (2, 3): + for i in range(len(t) - n + 1): + tokens.add(t[i:i + n]) + return tokens + + @staticmethod + def _extract_catalog_keywords(summary: str, max_kw: int = 3) -> List[str]: + """Extract salient keywords from a catalog summary via simple heuristics. + + Uses word-length filtering, Chinese character detection, and CJK n-gram + extraction to pick the most informative tokens. For CJK-heavy text + (which does not use whitespace word boundaries), consecutive CJK + character runs are extracted as additional candidate tokens. + + No LLM or embedding involved. + + Args: + summary: Document summary text from the compiled catalog. + max_kw: Maximum number of keywords to return. + + Returns: + List of up to *max_kw* keywords. + """ + cls = AgenticSearch + if max_kw <= 0: + return [] + summary_text = str(summary or "").strip() + if not summary_text: + return [] + import re as _re + + # Split on whitespace and common punctuation (incl. CJK punctuation) + tokens = _re.split( + r'[\s,;\uff0c\uff1b\u3001\u3002\uff1a:!?\uff01\uff1f()\[\]{}\u201c\u201d\u2018\u2019\u0022\u0027/\\|`~@#$%^&*=+<>]+', + summary_text, + ) + + # For CJK text, also extract consecutive CJK character runs (2-6 chars) + # so that e.g. "停车位申请条件" yields ["停车位申请条件", "停车位", "申请条件", ...] + cjk_runs = _re.findall(r'[\u4e00-\u9fff\u3400-\u4dbf]{2,}', summary_text) + # Generate sub-phrases from long CJK runs (bigrams/trigrams/4-grams) + cjk_ngrams: List[str] = [] + max_ngram_per_run = 40 + for run in cjk_runs: + cjk_ngrams.append(run) + if len(run) > 4: + # Extract 2-4 char sub-phrases from each run + added = 0 + for n in (4, 3, 2): + for i in range(len(run) - n + 1): + cjk_ngrams.append(run[i:i + n]) + added += 1 + if added >= max_ngram_per_run: + break + if added >= max_ngram_per_run: + break + + tokens = tokens + cjk_ngrams + + # Filter: keep tokens with appropriate length and not purely numeric + candidates = [ + t for t in tokens + if t + and len(t) >= cls._CATALOG_KEYWORD_MIN_LEN + and not t.isdigit() + and len(t) <= cls._CATALOG_KEYWORD_MAX_LEN + and not _re.fullmatch(r"[_\-.]+", t) + ] + # Prefer longer tokens (more specific) + candidates.sort(key=len, reverse=True) + # Deduplicate case-insensitively + seen: Set[str] = set() + chosen_norms: List[str] = [] + result: List[str] = [] + for c in candidates: + lower = c.lower() + if lower not in seen: + # Avoid noisy micro-fragments when a longer token already exists. + if len(lower) <= 4 and any(lower in kept for kept in chosen_norms): + continue + seen.add(lower) + chosen_norms.append(lower) + result.append(c) + if len(result) >= max_kw: + break + return result + + def _build_enriched_catalog_listing( + self, + catalog: List[Dict[str, str]], + max_entries: Optional[int] = None, + ) -> str: + """Build an enriched catalog listing with keywords for FAST Step 1. + + Compared to the plain ``[i] name: summary[:200]`` format, this adds + extracted keywords to help the LLM make more informed document + selections. + + Args: + catalog: Entries from ``document_catalog.json``. + max_entries: Cap to prevent prompt overflow. + + Returns: + Formatted listing string for injection into the FAST query + analysis prompt. + """ + if not isinstance(catalog, list) or not catalog: + return "" + lines: List[str] = [] + _max = max_entries if max_entries is not None else self._CATALOG_LISTING_MAX_ENTRIES + if _max <= 0: + return "" + _trunc = self._CATALOG_SUMMARY_TRUNCATE + for i, entry in enumerate(catalog[:_max]): + if not isinstance(entry, dict): + continue + name = str(entry.get("name") or entry.get("path") or "") + summary = str(entry.get("summary") or "") + # Keep one-line prompt entries to avoid accidental prompt pollution. + name = " ".join(name.split()) + summary = " ".join(summary.split()) + if not name: + name = f"doc_{i}" + kws = AgenticSearch._extract_catalog_keywords(summary) + kw_str = ", ".join(kws) if kws else "" + shown_summary = summary[:_trunc] + if len(summary) > _trunc: + shown_summary += "..." + if kw_str: + lines.append(f"[{i}] {name}: {shown_summary} [Keywords: {kw_str}]") + else: + lines.append(f"[{i}] {name}: {shown_summary}") + return "\n".join(lines) + + def _build_answer_context( + self, + best_file_path: str, + artifacts: CompileArtifacts, + ) -> Optional[str]: + """Build document context from catalog for wiki-enhanced answer generation. + + Returns a short context string describing the source document, or + None when no catalog entry exists for *best_file_path*. + + Args: + best_file_path: Path of the top-ranked file from Step 2. + artifacts: Compile artifact availability context. + + Returns: + Context string or None. + """ + if not artifacts.catalog_map: + return None + entry = artifacts.catalog_map.get(best_file_path) + if not entry: + return None + name = entry.get("name", Path(best_file_path).name) + summary = entry.get("summary", "") + if not summary: + return None + return f"Source Document: {name}\nDocument Overview: {summary}" + + async def _tree_guided_sample( + self, + file_path: str, + query: str, + *, + match_objects: Optional[List[Dict[str, Any]]] = None, + max_chars: int = 0, + artifacts: Optional["CompileArtifacts"] = None, + pre_navigated_leaves: Optional[List[Any]] = None, + ) -> Optional[str]: + """Tree-guided evidence sampling: use compiled tree index to locate + relevant sections, then read precise char_range content. + + Falls back to None when no tree index is available, letting callers + use their default sampling strategy (rga windows, Monte Carlo, etc.). + + This method is designed to be called from both FAST and DEEP modes: + - FAST: called inside _rga_evidence() per-file loop + - DEEP: called before/alongside Monte Carlo sampling + + Args: + file_path: Absolute path to the target file. + query: User query for LLM-driven branch selection. + match_objects: Optional rga match objects for hybrid evidence. + max_chars: Character budget for this file's evidence. + Uses ``_FAST_MAX_EVIDENCE_CHARS`` when 0. + artifacts: Compile artifact context; when None, probes lazily. + pre_navigated_leaves: Pre-computed leaf nodes from a prior + ``navigate()`` call. When provided the method skips the + LLM navigation step (avoids duplicate LLM calls). + + Returns: + Formatted evidence string with tree-navigated sections, or None + when tree index is unavailable (caller should fall back). + """ + if max_chars <= 0: + max_chars = self._FAST_MAX_EVIDENCE_CHARS + + print(f"SEARCH_WIKI_DEBUG [S1] _tree_guided_sample: file_path={file_path}", flush=True) + + # --- Guard: tree availability --- + if artifacts is not None: + if file_path not in artifacts.tree_available_paths: + return None + else: + # Lazy probe when artifacts not provided (DEEP mode entry) + indexer = self._get_tree_indexer() + if indexer is None or not indexer.has_tree(file_path): + return None + + fname = Path(file_path).name + + # --- Obtain leaf nodes --- + leaves = pre_navigated_leaves + if leaves is None: + try: + indexer = self._get_tree_indexer() + if indexer is None: + return None + tree = indexer.load_tree(file_path) + if tree is None or tree.root is None: + return None + leaves = await indexer.navigate( + tree, query, + max_results=self._TREE_SAMPLE_MAX_SECTIONS, + ) + except Exception: + return None + + if not leaves: + return None + + # --- Classify leaves by extraction method --- + trimmed = leaves[: self._TREE_SAMPLE_MAX_SECTIONS] + page_leaves, char_leaves, table_and_summary = self._classify_leaves(trimmed) + print(f"SEARCH_WIKI_DEBUG [S2] classify_leaves: page={len(page_leaves)}, char={len(char_leaves)}, table_summary={len(table_and_summary)}", flush=True) + + # Collect (leaf, segment) pairs preserving original leaf order + leaf_segments: List[tuple] = [] # (leaf, segment_text) + + # -- Phase A: table / summary-only leaves -- + for leaf in table_and_summary: + leaf_segments.append((leaf, leaf.summary)) + + # -- Phase B: batch page-level extraction (single IO) -- + page_segment_map: dict = {} # id(leaf) -> segment + if page_leaves: + all_pages: set = set() + for _leaf, (sp, ep) in page_leaves: + all_pages.update(range(sp, ep + 1)) + try: + page_contents = DocumentExtractor.extract_pages( + file_path, sorted(all_pages), + ) + page_map = {pc.page_number: pc.content for pc in page_contents} + + for leaf, (sp, ep) in page_leaves: + seg_parts = [] + for p in range(sp, ep + 1): + text = page_map.get(p, "") + if text.strip(): + seg_parts.append(text) + if seg_parts: + page_segment_map[id(leaf)] = "\n".join(seg_parts) + elif getattr(leaf, 'summary', None): + page_segment_map[id(leaf)] = leaf.summary + except (FileNotFoundError, PermissionError): + raise # 文件系统错误应传播 + except Exception as e: + _loguru_logger.warning( + f"[TreeSample] Page extraction failed for {fname}: {e}, " + f"falling back to char_range for {len(page_leaves)} leaves" + ) + # Demote page_leaves → char_leaves + for leaf, _ in page_leaves: + if hasattr(leaf, 'char_range') and leaf.char_range: + char_leaves.append(leaf) + elif getattr(leaf, 'summary', None): + leaf_segments.append((leaf, leaf.summary)) + page_leaves_ok = False + else: + page_leaves_ok = True + + if page_leaves_ok: + for leaf, _ in page_leaves: + seg = page_segment_map.get(id(leaf)) + if seg: + leaf_segments.append((leaf, seg)) + # If page extraction failed, demoted leaves are now in char_leaves + + # -- Phase C: char_range extraction (compile-consistent content) -- + if char_leaves: + full_text = self._load_compile_content(self.work_path, file_path) + if not full_text: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + for leaf in char_leaves: + start, end = leaf.char_range + if self._is_valid_char_range(start, end, len(full_text)) and full_text: + segment = full_text[start:end] + if segment.strip(): + leaf_segments.append((leaf, segment)) + elif getattr(leaf, 'summary', None): + leaf_segments.append((leaf, leaf.summary)) + elif getattr(leaf, 'summary', None): + _loguru_logger.debug( + f"[TreeSample] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" + ) + leaf_segments.append((leaf, leaf.summary)) + + # --- Build parts with budget control --- + parts: List[str] = [] + total_chars = 0 + for leaf, segment in leaf_segments: + segment = segment[: self._TREE_SAMPLE_SECTION_MAX_CHARS] + if not segment.strip(): + continue + page_info = "" + if getattr(leaf, 'page_range', None): + ps, pe = leaf.page_range + page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" + type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" + header = f"[{fname} \u2192 {leaf.title}{page_info}{type_tag}]" + chunk = f"{header}\n{segment}" + if total_chars + len(chunk) > max_chars: + remaining = max_chars - total_chars + if remaining > 200: + parts.append(chunk[:remaining]) + total_chars += remaining + break + parts.append(chunk) + total_chars += len(chunk) + + # --- Optional rga supplement --- + if ( + self._TREE_SAMPLE_RGA_SUPPLEMENT + and match_objects + and total_chars < max_chars + ): + hit_lines: List[int] = [] + for m in match_objects: + ln = m.get("data", {}).get("line_number") + if isinstance(ln, int): + hit_lines.append(ln) + if hit_lines: + ext = Path(file_path).suffix.lower() + if ext in self._FAST_TEXT_EXTENSIONS: + rga_ctx = self._read_context_windows( + file_path, hit_lines, + window=self._FAST_CONTEXT_WINDOW, + max_chars=max_chars - total_chars, + ) + if rga_ctx: + rga_section = f"[{fname} \u2192 rga hits]\n{rga_ctx}" + parts.append(rga_section) + total_chars += len(rga_section) + + if not parts: + return None + + evidence = "\n\n".join(parts) + print(f"SEARCH_WIKI_DEBUG [S3] _tree_guided_sample result: len={len(evidence) if evidence else 0}", flush=True) + await self._logger.info( + f"[TreeSample] {fname}: " + f"{len(parts)} sections, {total_chars} chars " + f"(pre_nav={'yes' if pre_navigated_leaves else 'no'})" + ) + return evidence + + async def _cached_navigate_tree( + self, + file_path: str, + query: str, + nav_cache: Dict[str, str], + ) -> Optional[str]: + """``_navigate_tree_for_evidence`` with per-query dedup cache.""" + cache_key = f"{file_path}::{query}" + if cache_key in nav_cache: + return nav_cache[cache_key] + result = await self._navigate_tree_for_evidence(file_path, query) + if isinstance(result, str) and result.strip(): + nav_cache[cache_key] = result + return result + + async def _collect_deep_tree_evidence( + self, + file_paths: List[str], + query: str, + *, + scope: Optional["_PathScope"] = None, + nav_cache: Optional[Dict[str, str]] = None, + ) -> Dict[str, str]: + """Full tree navigation for DEEP mode primary files. + + Runs ``_navigate_tree_for_evidence`` (complement nav, table supplement, + referenced-page gap-fill) on each file. Returns a dict mapping + file paths to raw evidence text. The evidence is used to + **supplement** (not replace) Monte Carlo sampling. + + When *scope* is provided, only files within the search path scope + are navigated — prevents cross-document evidence contamination. + When *nav_cache* is provided, results are cached to avoid + duplicate navigation across pipeline phases. + """ + indexer = self._get_tree_indexer() + if indexer is None: + return {} + + if scope: + file_paths = [fp for fp in file_paths if scope.contains(fp)] + if not file_paths: + return {} + + nav_fps = [fp for fp in file_paths[:self._DEEP_PRE_NAV_MAX_FILES] + if indexer.has_tree(fp)] + if not nav_fps: + return {} + + if nav_cache is not None: + results = await asyncio.gather( + *[self._cached_navigate_tree(fp, query, nav_cache) for fp in nav_fps], + return_exceptions=True, + ) + else: + results = await asyncio.gather( + *[self._navigate_tree_for_evidence(fp, query) for fp in nav_fps], + return_exceptions=True, + ) + + evidence_dict: Dict[str, str] = {} + for fp, res in zip(nav_fps, results): + if isinstance(res, Exception): + await self._logger.warning( + f"[Phase 2.5:DirectTree] Navigation failed for " + f"{Path(fp).name}: {res}" + ) + elif isinstance(res, str) and res.strip(): + evidence_dict[fp] = res + + if evidence_dict: + total_len = sum(len(v) for v in evidence_dict.values()) + await self._logger.info( + f"[Phase 2.5:DirectTree] {len(evidence_dict)} files, " + f"{total_len} chars" + ) + return evidence_dict + + @classmethod + def _classify_leaves(cls, leaves: list) -> Tuple[List[tuple], List, List]: + """Classify leaf nodes by preferred extraction strategy. + + For non-table leaves, **char_range** (kreuzberg markdown) is preferred + over page_range (pypdf raw text) because compile-time extraction + preserves table layout and column structure far better than pypdf's + ``extract_text()``. page_range remains available on each leaf for + table-supplement filtering even when the leaf is routed to char_leaves. + + Thin char_range nodes (span < ``_CHAR_RANGE_MIN_SPAN``) are demoted + to page-level extraction when a valid page_range exists, as they + typically represent TOC entries whose char offsets only cover the + section title rather than the actual content. + + Returns: + (page_leaves, char_leaves, summary_leaves) triple: + - page_leaves: list of (leaf, page_range) — page-level extraction + - char_leaves: list of leaf — kreuzberg char_range extraction + - summary_leaves: list of leaf — only summary available + """ + page_leaves: List[tuple] = [] + char_leaves: List = [] + summary_leaves: List = [] + min_span = cls._CHAR_RANGE_MIN_SPAN + + for leaf in leaves: + # Table nodes: prefer page-level extraction for raw original content + if getattr(leaf, 'content_type', 'text') == 'table': + page_range = getattr(leaf, 'page_range', None) + if ( + page_range + and len(page_range) == 2 + and page_range[0] is not None + and page_range[0] > 0 + ): + page_leaves.append((leaf, page_range)) + elif getattr(leaf, 'summary', None): + summary_leaves.append(leaf) + else: + char_leaves.append(leaf) + continue + + # Non-table leaves: prefer char_range (kreuzberg markdown) over + # page_range (pypdf raw text) for higher-fidelity table rendering. + has_char = hasattr(leaf, 'char_range') and leaf.char_range + page_range = getattr(leaf, 'page_range', None) + has_page = ( + page_range + and len(page_range) == 2 + and page_range[0] is not None + and page_range[0] > 0 + ) + + if has_char: + start, end = leaf.char_range + span = end - start if end > start else 0 + if span < min_span and has_page: + page_leaves.append((leaf, page_range)) + else: + char_leaves.append(leaf) + elif has_page: + page_leaves.append((leaf, page_range)) + elif getattr(leaf, 'summary', None): + summary_leaves.append(leaf) + + return page_leaves, char_leaves, summary_leaves + + def _is_valid_char_range( + self, start: int, end: int, text_len: int, + ) -> bool: + """Check whether a char_range is valid for slicing. + + A range is invalid when it covers more than + ``_CHAR_RANGE_MAX_SPAN_RATIO`` of the document (likely a + whole-document fallback) or when *end <= start*. + """ + if start < 0 or end <= start or text_len <= 0: + return False + span_ratio = (end - start) / text_len + return span_ratio < self._CHAR_RANGE_MAX_SPAN_RATIO + + @staticmethod + def _is_evidence_sufficient(evidence: str, min_chars: int = 0) -> bool: + """Check whether collected evidence has enough substance to answer a query. + + Uses a length threshold as a lightweight, domain-agnostic proxy. + Empty or near-empty evidence (e.g., only headers with no data) + fails the check, triggering a retry with expanded parameters. + """ + if not evidence: + return False + stripped = evidence.strip() + return len(stripped) >= min_chars + + _MULTI_COMPONENT_PATTERNS: Tuple[Tuple[str, ...], ...] = ( + ("balance sheet", "income statement"), + ("balance sheet", "cash flow"), + ("income statement", "cash flow"), + ("accounts payable", "cost of"), + ("accounts payable", "inventory"), + ("current assets", "current liabilities"), + ("revenue", "net income", "earnings"), + ("operating income", "depreciation"), + ) + + @staticmethod + def _decompose_query_components(query: str) -> List[str]: + """Extract distinct data-source components from a multi-part query. + + Scans for known multi-component patterns (e.g. a ratio needing data + from both Balance Sheet and Income Statement) and returns a list of + component phrases that the evidence should cover. + """ + q = query.lower() + components: List[str] = [] + for group in AgenticSearch._MULTI_COMPONENT_PATTERNS: + hits = [phrase for phrase in group if phrase in q] + if len(hits) >= 2: + components.extend(hits) + if not components: + financial_keywords = [ + "balance sheet", "income statement", "cash flow", + "accounts payable", "accounts receivable", "inventory", + "current liabilities", "current assets", "total assets", + "revenue", "cost of", "cogs", "depreciation", "amortization", + "operating income", "net income", "earnings", + ] + for kw in financial_keywords: + if kw in q: + components.append(kw) + seen: set = set() + return [c for c in components if not (c in seen or seen.add(c))] + + @staticmethod + def _check_leaf_coverage( + leaves: list, components: List[str], + ) -> Tuple[List[str], List[str]]: + """Check which query components are covered by the navigated leaves. + + Returns: + (covered, missing) — lists of component phrases. + """ + if not leaves or not components: + return [], list(components) + leaf_text = " ".join( + (getattr(l, 'title', '') or '') + " " + (getattr(l, 'summary', '') or '') + for l in leaves + ).lower() + covered = [c for c in components if c in leaf_text] + missing = [c for c in components if c not in leaf_text] + return covered, missing + + @staticmethod + def _extract_referenced_pages(text: str) -> Set[int]: + """Extract page numbers referenced in evidence text. + + Detects cross-references like 'page 60', 'pages 45-47', 'pp. 12-15' + that hint at data-bearing pages not yet included in evidence. + """ + pages: Set[int] = set() + for m in re.finditer( + r"\b(?:pages?|pp?\.)\s*(\d+)\s*[-\u2013]\s*(\d+)", + text, re.IGNORECASE, + ): + start, end = int(m.group(1)), int(m.group(2)) + if 0 < start <= end and end - start <= 10: + pages.update(range(start, end + 1)) + for m in re.finditer( + r"\b(?:pages?|pp?\.)\s*(\d+)\b", text, re.IGNORECASE, + ): + p = int(m.group(1)) + if 0 < p <= 500: + pages.add(p) + return pages + + @staticmethod + def _load_compile_content( + work_path: Path, file_path: str, + ) -> Optional[str]: + """Load the ENHANCED content cached at compile time. + + Compile stores the kreuzberg ENHANCED-profile content alongside the + tree index so that search-time ``char_range`` slicing operates on + the *same* text the ranges were computed from. Returns ``None`` + when the cache file is missing (e.g. pre-cache compile run). + """ + try: + from sirchmunk.utils.file_utils import get_fast_hash + file_hash = get_fast_hash(file_path) + if not file_hash: + return None + cache_path = ( + work_path / ".cache" / "compile" / "content" / f"{file_hash}.txt" + ) + if cache_path.exists(): + return cache_path.read_text(encoding="utf-8") + except Exception: + pass + return None + + @staticmethod + def _load_table_digest( + work_path: Path, file_hash: str, + ) -> Optional[List[Dict[str, Any]]]: + """Load pre-compiled table digest for a file. + + Returns the list of table entries from the digest JSON, or None + if no digest exists or loading fails. + """ + digest_path = ( + work_path / ".cache" / "compile" / "table_digests" / f"{file_hash}.json" + ) + if not digest_path.exists(): + return None + try: + data = json.loads(digest_path.read_text(encoding="utf-8")) + return data.get("tables", []) + except Exception: + return None + + @staticmethod + def _filter_tables_by_page_range( + tables: List[Dict[str, Any]], + page_start: int, + page_end: int, + ) -> List[Dict[str, Any]]: + """Filter tables whose page_number falls within the given range (inclusive).""" + return [ + t for t in tables + if t.get("page_number") is not None + and page_start <= t["page_number"] <= page_end + ] + + _TABLE_RELEVANCE_MIN_PREFIX = 5 + _TABLE_STRUCTURE_BONUS: float = 0.25 + """Bonus score for tables exhibiting structured data characteristics + (high row count, numeric density). Applied additively to the keyword + relevance score so that data-rich tables are preferred when keyword + scores tie.""" + _TABLE_STRUCTURE_MIN_ROWS: int = 5 + """Minimum ``|``-delimited rows for a table to qualify for the + structure bonus.""" + _TABLE_STRUCTURE_MIN_NUMERIC_RATIO: float = 0.15 + """Minimum ratio of numeric tokens to total tokens for the bonus.""" + + @staticmethod + def _score_table_relevance( + markdown: str, query_tokens: frozenset, + ) -> float: + """Score a table's relevance to the query via token overlap. + + Uses two matching strategies per token: + + 1. **Exact substring** — fast check whether the token appears + anywhere in the table text (original behaviour). + 2. **Prefix match** — handles morphological variation such as + plural/singular (*inventory* ↔ *inventories*) by comparing + word prefixes of at least ``_TABLE_RELEVANCE_MIN_PREFIX`` + characters. Only attempted when the exact match misses. + + Returns a value in [0, 1] representing the fraction of + *query_tokens* matched. + """ + if not markdown or not query_tokens: + return 0.0 + + min_pfx = AgenticSearch._TABLE_RELEVANCE_MIN_PREFIX + md_lower = markdown.lower() + md_words = None # lazily built on first prefix-match attempt + + hits = 0 + for tok in query_tokens: + if tok in md_lower: + hits += 1 + continue + # Prefix-match fallback + pfx_len = min(len(tok), min_pfx) + if pfx_len < 4: + continue + if md_words is None: + md_words = frozenset(md_lower.split()) + prefix = tok[:pfx_len] + if any( + w[:pfx_len] == prefix + for w in md_words + if len(w) >= pfx_len + ): + hits += 1 + + return hits / len(query_tokens) + + @staticmethod + def _score_table_structure(markdown: str) -> float: + """Score a table's structural richness (row count + numeric density). + + Data-dense tables (financial statements, balance sheets) score + higher than narrative paragraphs that happen to contain a small + embedded table. The score is in [0, 1] and is added as a bonus + to the keyword relevance score during table ranking. + """ + if not markdown: + return 0.0 + + rows = markdown.count("\n") + if rows < AgenticSearch._TABLE_STRUCTURE_MIN_ROWS: + return 0.0 + + tokens = markdown.split() + if not tokens: + return 0.0 + + numeric_count = sum( + 1 for t in tokens + if any(c.isdigit() for c in t) + ) + numeric_ratio = numeric_count / len(tokens) + + if numeric_ratio < AgenticSearch._TABLE_STRUCTURE_MIN_NUMERIC_RATIO: + return 0.0 + + row_score = min(rows / 30.0, 1.0) + num_score = min(numeric_ratio / 0.4, 1.0) + return (row_score * 0.5 + num_score * 0.5) + + @staticmethod + def _deduplicate_table_sections( + primary_ev: str, secondary_ev: str, + ) -> str: + """Remove table sections from *secondary_ev* whose pages already + appear in *primary_ev*. + + Matching is based on ``[Table from page N]`` and ``[Tables pp.X-Y]`` + headers. Non-table content in *secondary_ev* is preserved intact. + """ + if not primary_ev or not secondary_ev: + return secondary_ev + + covered: Set[int] = { + int(m.group(1)) + for m in re.finditer(r"\[Table from page (\d+)\]", primary_ev) + } + for m in re.finditer(r"\[Tables pp\.(\d+)-(\d+)\]", primary_ev): + covered.update(range(int(m.group(1)), int(m.group(2)) + 1)) + + if not covered: + return secondary_ev + + blocks = secondary_ev.split("\n\n") + kept: List[str] = [] + for block in blocks: + page_m = re.search(r"\[Table from page (\d+)\]", block) + if page_m and int(page_m.group(1)) in covered: + continue + kept.append(block) + + result = "\n\n".join(kept) + return result if result.strip() else "" + + @staticmethod + def _format_table_evidence( + tables: List[Dict[str, Any]], + max_chars: int = 20_000, + query: str = "", + ) -> str: + """Format table digest entries as LLM-friendly evidence text. + + When *query* is provided, tables are **sorted by relevance** to the + query before budget truncation, ensuring critical tables are included + even when they appear late in page order. + + Strategy: + - Query-relevant tables are prioritised via keyword overlap scoring + - Each table prefixed with "[Table from page N]" + - Large tables truncated with "(truncated)" note + + Returns concatenated formatted table evidence string. + """ + if not tables: + return "" + + ordered = tables + if query: + query_tokens = frozenset( + tok for tok in query.lower().split() if len(tok) >= 2 + ) + if query_tokens: + struct_bonus = AgenticSearch._TABLE_STRUCTURE_BONUS + scored = [ + ( + AgenticSearch._score_table_relevance( + t.get("markdown", ""), query_tokens, + ) + + struct_bonus * AgenticSearch._score_table_structure( + t.get("markdown", ""), + ), + idx, + t, + ) + for idx, t in enumerate(tables) + ] + scored.sort(key=lambda x: (-x[0], x[1])) + ordered = [t for _, _, t in scored] + + parts: List[str] = [] + remaining = max_chars + + for table in ordered: + if remaining <= 0: + break + + page = table.get("page_number", "?") + markdown = table.get("markdown", "") + + if not markdown: + continue + + header = f"[Table from page {page}]" + + if len(markdown) <= remaining: + parts.append(f"{header}\n{markdown}") + remaining -= len(markdown) + len(header) + 2 + else: + truncated = markdown[:remaining] + parts.append(f"{header}\n{truncated}\n(truncated)") + remaining = 0 + + return "\n\n".join(parts) + + @staticmethod + def _append_evidence_part( + parts: List[str], fname: str, leaf, segment: str, + *, max_chars: int = 3000, + ) -> None: + """Format and append one leaf's evidence to *parts* (in-place).""" + text = segment[:max_chars] + if not text.strip(): + return + page_info = "" + if getattr(leaf, 'page_range', None): + ps, pe = leaf.page_range + page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" + type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" + header = f"[{fname} \u2192 {leaf.title}{page_info}{type_tag}]" + parts.append(f"{header}\n{text}") + + async def _navigate_tree_for_evidence( + self, + file_path: str, + query: str, + *, + max_results: int = 8, + match_objects: Optional[List[Dict[str, Any]]] = None, + ) -> Optional[str]: + """LLM-driven tree navigation: select relevant sections and read leaf content. + + Uses 1 LLM call to drill into the compiled tree index for + *file_path*, returning concatenated leaf content as evidence. + Returns None when no tree cache is available. + + When *match_objects* (RGA hit dicts) are provided, keyword-level + context windows are appended as supplementary evidence after tree + navigation, fusing structural and keyword signals. + + Extraction priority (highest first): + 1. char_range – compile-time ENHANCED content slice (preserves tables) + 2. page_range – page-level extraction via DocumentExtractor (fallback) + 3. leaf.summary – last resort + """ + indexer = self._get_tree_indexer() + print(f"SEARCH_WIKI_DEBUG [N1] _navigate_tree_for_evidence: file_path={file_path}", flush=True) + if indexer is None: + return None + tree = indexer.load_tree(file_path) + if tree is None or tree.root is None: + return None + + try: + leaves = await indexer.navigate(tree, query, max_results=max_results) + except Exception: + return None + + print(f"SEARCH_WIKI_DEBUG [N2] navigate_result: {len(leaves) if leaves else 0} leaves", flush=True) + + if not leaves: + return None + + fname = Path(file_path).name + parts: List[str] = [] + + # ── Phase 1: classify leaves by available extraction method ── + page_leaves, char_leaves, summary_only = self._classify_leaves(leaves) + print(f"SEARCH_WIKI_DEBUG [N3] classify_leaves: page={len(page_leaves)}, char={len(char_leaves)}, summary={len(summary_only)}", flush=True) + + for leaf in summary_only: + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + + # ── Phase 2: batch page-level extraction (single IO) ── + if page_leaves: + all_pages: set = set() + for _leaf, (sp, ep) in page_leaves: + all_pages.update(range( + max(1, sp - self._NAV_PAGE_MARGIN), + ep + self._NAV_PAGE_MARGIN + 1, + )) + try: + page_contents = DocumentExtractor.extract_pages( + file_path, sorted(all_pages), + ) + page_map = {pc.page_number: pc.content for pc in page_contents} + + for leaf, (sp, ep) in page_leaves: + segment_parts = [] + for p in range(sp, ep + 1): + text = page_map.get(p, "") + if text.strip(): + segment_parts.append(text) + if segment_parts: + self._append_evidence_part( + parts, fname, leaf, "\n".join(segment_parts), + ) + elif getattr(leaf, 'summary', None): + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + except (FileNotFoundError, PermissionError): + raise # 文件系统错误应传播 + except Exception as e: + _loguru_logger.warning( + f"[TreeNav] Page extraction failed for {fname}: {e}, " + f"falling back to char_range for {len(page_leaves)} leaves" + ) + # Demote page_leaves → char_leaves for char_range fallback + for leaf, _ in page_leaves: + if hasattr(leaf, 'char_range') and leaf.char_range: + char_leaves.append(leaf) + elif getattr(leaf, 'summary', None): + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + print(f"SEARCH_WIKI_DEBUG [N4] page_extraction: page_leaves_ok=False", flush=True) + else: + print(f"SEARCH_WIKI_DEBUG [N4] page_extraction: page_leaves_ok=True", flush=True) + + # ── Phase 3: char_range extraction (compile-consistent content) ── + if char_leaves: + # Prefer compile-time ENHANCED content (matches char_range offsets + # exactly). Fall back to fast_extract only when cache is absent. + full_text = self._load_compile_content(self.work_path, file_path) + if not full_text: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + # Leaves whose char_range is invalid but have a valid page_range + # are demoted to page extraction instead of discarding to summary. + page_fallback_leaves: List[tuple] = [] + + for leaf in char_leaves: + start, end = leaf.char_range + if self._is_valid_char_range(start, end, len(full_text)) and full_text: + segment = full_text[start:end] + if segment.strip(): + self._append_evidence_part( + parts, fname, leaf, segment, + ) + elif getattr(leaf, 'summary', None): + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + else: + # char_range covers too much of the document (or text is + # empty). Try page_range extraction before falling back + # to summary. + pr = getattr(leaf, 'page_range', None) + if ( + pr + and len(pr) == 2 + and pr[0] is not None + and pr[0] > 0 + ): + page_fallback_leaves.append((leaf, pr)) + elif getattr(leaf, 'summary', None): + _loguru_logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), " + f"using summary" + ) + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + + # Batch page extraction for demoted leaves (same pattern as Phase 2) + if page_fallback_leaves: + all_fb_pages: set = set() + for _lf, (sp, ep) in page_fallback_leaves: + all_fb_pages.update(range( + max(1, sp - self._NAV_PAGE_MARGIN), + ep + self._NAV_PAGE_MARGIN + 1, + )) + try: + fb_contents = DocumentExtractor.extract_pages( + file_path, sorted(all_fb_pages), + ) + fb_map = {pc.page_number: pc.content for pc in fb_contents} + for lf, (sp, ep) in page_fallback_leaves: + seg_parts = [ + fb_map[p] for p in range(sp, ep + 1) + if fb_map.get(p, "").strip() + ] + if seg_parts: + self._append_evidence_part( + parts, fname, lf, "\n".join(seg_parts), + ) + elif getattr(lf, 'summary', None): + self._append_evidence_part( + parts, fname, lf, lf.summary, + ) + except Exception: + for lf, _ in page_fallback_leaves: + if getattr(lf, 'summary', None): + self._append_evidence_part( + parts, fname, lf, lf.summary, + ) + + # ── Phase 4: Complementary navigation for multi-component queries ── + # When a query requires data from multiple document sections (e.g. + # Balance Sheet + Income Statement for a ratio), the initial navigate + # may only reach one component. Detect missing components and run a + # focused second navigate pass with a refined query. + _query_components = self._decompose_query_components(query) + if len(_query_components) >= self._NAV_COMPLEMENT_MIN_COMPONENTS: + _covered, _missing = self._check_leaf_coverage(leaves, _query_components) + if _missing: + _complement_query = f"{query} — focus on: {', '.join(_missing)}" + try: + _existing_ids = {id(l) for l in leaves} + comp_leaves = await indexer.navigate( + tree, _complement_query, max_results=max_results, + ) + comp_new = [l for l in (comp_leaves or []) if id(l) not in _existing_ids] + if comp_new: + c_page, c_char, c_summary = self._classify_leaves(comp_new) + for cl in c_summary: + self._append_evidence_part(parts, fname, cl, cl.summary) + if c_page: + c_all_pages: set = set() + for _cl, (csp, cep) in c_page: + c_all_pages.update(range(csp, cep + 1)) + try: + c_contents = DocumentExtractor.extract_pages( + file_path, sorted(c_all_pages), + ) + c_map = {pc.page_number: pc.content for pc in c_contents} + for cl, (csp, cep) in c_page: + c_seg = [c_map[p] for p in range(csp, cep + 1) if c_map.get(p, "").strip()] + if c_seg: + self._append_evidence_part(parts, fname, cl, "\n".join(c_seg)) + except Exception: + pass + if c_char: + c_text = self._load_compile_content(self.work_path, file_path) or "" + for cl in c_char: + s, e = cl.char_range + if self._is_valid_char_range(s, e, len(c_text)) and c_text: + seg = c_text[s:e] + if seg.strip(): + self._append_evidence_part(parts, fname, cl, seg) + leaves = list(leaves) + comp_new + print( + f"SEARCH_WIKI_DEBUG [N3.2] complement_nav: " + f"missing={_missing}, new_leaves={len(comp_new)}", + flush=True, + ) + except Exception: + pass + + # ── Plan 3: Retry with expanded results if evidence is insufficient ── + # Triggers on: (a) zero evidence parts, OR (b) evidence too thin. + _current_ev_text = "\n\n".join(parts) + _needs_retry = ( + max_results < self._NAV_RETRY_EXPANDED_RESULTS + and not self._is_evidence_sufficient( + _current_ev_text, self._NAV_RETRY_MIN_EVIDENCE_CHARS, + ) + ) + if _needs_retry: + try: + retry_leaves = await indexer.navigate( + tree, query, + max_results=self._NAV_RETRY_EXPANDED_RESULTS, + ) + if retry_leaves: + r_page, r_char, r_summary = self._classify_leaves(retry_leaves) + for rl in r_summary: + self._append_evidence_part(parts, fname, rl, rl.summary) + + # Page-level extraction for retry (mirrors Phase 2) + if r_page: + r_all_pages: set = set() + for _rl, (rsp, rep) in r_page: + r_all_pages.update(range(rsp, rep + 1)) + try: + r_page_contents = DocumentExtractor.extract_pages( + file_path, sorted(r_all_pages), + ) + r_page_map = {pc.page_number: pc.content for pc in r_page_contents} + for rl, (rsp, rep) in r_page: + r_seg = [r_page_map[p] for p in range(rsp, rep + 1) if r_page_map.get(p, "").strip()] + if r_seg: + self._append_evidence_part(parts, fname, rl, "\n".join(r_seg)) + except Exception: + pass + + # Char-range extraction for retry (mirrors Phase 3) + if r_char: + r_text = self._load_compile_content(self.work_path, file_path) or "" + for rl in r_char: + s, e = rl.char_range + if self._is_valid_char_range(s, e, len(r_text)) and r_text: + seg = r_text[s:e] + if seg.strip(): + self._append_evidence_part(parts, fname, rl, seg) + + leaves = retry_leaves + print(f"SEARCH_WIKI_DEBUG [N3.1] retry_nav: {len(retry_leaves)} leaves", flush=True) + except Exception: + pass + + if not parts: + return None + + # Supplement with table evidence if available + _all_tables = None + try: + from sirchmunk.utils.file_utils import get_fast_hash + _file_hash = get_fast_hash(file_path) + if _file_hash: + _all_tables = self._load_table_digest( + self.work_path, _file_hash, + ) + if _all_tables and leaves: + _seen_pages: set = set() + for leaf in leaves: + if leaf.page_range: + ps, pe = leaf.page_range + page_key = (ps, pe) + if page_key in _seen_pages: + continue + _seen_pages.add(page_key) + leaf_tables = self._filter_tables_by_page_range( + _all_tables, ps, pe, + ) + if leaf_tables: + table_text = self._format_table_evidence( + leaf_tables, + max_chars=self._TABLE_EVIDENCE_PER_RANGE_CHARS, + query=query, + ) + if table_text: + parts.append( + f"[Tables pp.{ps}-{pe}]\n{table_text}" + ) + except Exception: + pass + + # ── Phase 5.5: Cross-section table supplement (conditional) ── + # Only supplements when existing evidence is below threshold + # to prevent evidence overload for queries already well-served. + _current_ev_len = sum(len(p) for p in parts) + if _all_tables and leaves and _current_ev_len < self._DEEP_CROSS_SECTION_MIN_EVIDENCE: + _leaf_page_set: Set[int] = set() + for _lf in leaves: + _pr = getattr(_lf, "page_range", None) + if _pr and len(_pr) == 2 and _pr[0] is not None: + _leaf_page_set.update(range( + max(1, _pr[0] - self._NAV_PAGE_MARGIN), + _pr[1] + self._NAV_PAGE_MARGIN + 1, + )) + _cross_tables = [ + t for t in _all_tables + if t.get("page_number") is not None + and t["page_number"] not in _leaf_page_set + ] + if _cross_tables: + _cross_ev = self._format_table_evidence( + _cross_tables, + max_chars=self._TABLE_CROSS_SECTION_CHARS, + query=query, + ) + if _cross_ev: + parts.append( + f"[{fname} - Cross-section Tables]\n{_cross_ev}" + ) + print( + f"SEARCH_WIKI_DEBUG [N5.3] cross_section_tables: " + f"uncovered_tables={len(_cross_tables)}, " + f"ev_len={len(_cross_ev)}", + flush=True, + ) + + # Plan 3: If evidence is still too thin, add full table digest as standalone + evidence = "\n\n".join(parts) + if ( + not self._is_evidence_sufficient( + evidence, self._NAV_RETRY_MIN_EVIDENCE_CHARS, + ) + and _all_tables + ): + standalone_table_ev = self._format_table_evidence( + _all_tables, + max_chars=self._TABLE_EVIDENCE_STANDALONE_CHARS, + query=query, + ) + if standalone_table_ev: + parts.append( + f"[{fname} - Standalone Table Evidence]\n{standalone_table_ev}" + ) + evidence = "\n\n".join(parts) + print(f"SEARCH_WIKI_DEBUG [N5.1] standalone_table_fallback: len={len(standalone_table_ev)}", flush=True) + + print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if _all_tables else 0}", flush=True) + + # ── Phase 6: Referenced-page gap-fill ── + # Scan evidence for page cross-references (e.g. TOC entries + # pointing to financial statements) and extract any that were + # not covered by the navigated leaves. + if parts: + _covered_pages: Set[int] = set() + for leaf in leaves: + pr = getattr(leaf, "page_range", None) + if pr and len(pr) == 2 and pr[0] is not None: + _covered_pages.update(range( + max(1, pr[0] - self._NAV_PAGE_MARGIN), + pr[1] + self._NAV_PAGE_MARGIN + 1, + )) + _referenced = self._extract_referenced_pages("\n\n".join(parts)) + _gap_pages = sorted(_referenced - _covered_pages)[ + : self._NAV_REF_PAGE_MAX + ] + if _gap_pages: + try: + _gap_contents = DocumentExtractor.extract_pages( + file_path, _gap_pages, + ) + for pc in _gap_contents: + if pc.content and pc.content.strip(): + parts.append( + f"[{fname} \u2192 referenced p.{pc.page_number}]" + f"\n{pc.content}" + ) + evidence = "\n\n".join(parts) + print( + f"SEARCH_WIKI_DEBUG [N5.2] ref_page_gap_fill: " + f"pages={_gap_pages}", + flush=True, + ) + except Exception: + pass + + # --- RGA keyword supplement: fuse keyword hits into tree evidence --- + if match_objects: + _ev_len = sum(len(p) for p in parts) + _rga_budget = max(0, self._FAST_MAX_EVIDENCE_CHARS - _ev_len) + if _rga_budget > 200: + hit_lines: List[int] = [ + m.get("data", {}).get("line_number") + for m in match_objects + if isinstance(m.get("data", {}).get("line_number"), int) + ] + ext = Path(file_path).suffix.lower() + rga_ctx: Optional[str] = None + if ext in self._FAST_TEXT_EXTENSIONS and hit_lines: + rga_ctx = self._read_context_windows( + file_path, hit_lines, + window=self._FAST_CONTEXT_WINDOW, + max_chars=_rga_budget, + ) + else: + snippet_parts: List[str] = [] + snippet_total = 0 + for m in match_objects: + text = m.get("data", {}).get("lines", {}).get("text", "").rstrip() + if text and snippet_total + len(text) < _rga_budget: + snippet_parts.append(text) + snippet_total += len(text) + if snippet_parts: + rga_ctx = "\n".join(snippet_parts) + if rga_ctx: + parts.append(f"[{fname} \u2192 keyword hits]\n{rga_ctx}") + evidence = "\n\n".join(parts) + + print(f"SEARCH_WIKI_DEBUG [N6] _navigate_tree_for_evidence result: len={len(evidence) if evidence else 0}", flush=True) + await self._logger.info( + f"[FAST:TreeNav] Extracted {len(parts)} sections, " + f"{len(evidence)} chars from {fname}" + ) + return evidence + + async def _fast_self_correct( + self, + query: str, + best_files: Optional[List[Dict[str, Any]]], + catalog_routed_files: List[str], + context: SearchContext, + ) -> Optional[str]: + """Attempt to gather alternative evidence when the first answer is rejected. + + Four strategies tried in order: + D) Re-sample the same primary file with expanded parameters (deeper sampling). + A) Tree-navigate a 2nd catalog-routed file not yet tried. + B) Retrieve the most semantically similar compiled cluster's content. + C) Tree-navigate the 2nd-best rga file if available. + + Returns alternative evidence string, or None if all strategies fail. + """ + first_file = best_files[0]["path"] if best_files else "" + + # Strategy D: Re-sample the SAME primary file with expanded parameters. + # The file was correct but the initial sampling may have missed key sections. + if first_file: + expanded_tree_ev = await self._navigate_tree_for_evidence( + first_file, query, + max_results=self._SELF_CORRECT_EXPANDED_NAV_RESULTS, + ) + if expanded_tree_ev and len(expanded_tree_ev.strip()) > 50: + await self._logger.info( + "[FAST:SelfCorrect] Strategy D succeeded: " + "expanded same-file tree navigation" + ) + return expanded_tree_ev + + # Strategy A: 2nd catalog-routed file via tree navigation + for fp in catalog_routed_files: + if fp == first_file: + continue + tree_ev = await self._navigate_tree_for_evidence(fp, query) + if tree_ev and len(tree_ev.strip()) > 50: + context.mark_file_read(fp) + return tree_ev + + # Strategy B: cluster content from knowledge storage + if self.embedding_client and self.knowledge_storage: + try: + qe = self.embedding_client.encode(query) + if qe is not None: + vec = qe.tolist() if hasattr(qe, "tolist") else list(qe) + hits = await self.knowledge_storage.search_similar_clusters( + query_embedding=vec, top_k=2, similarity_threshold=0.50, + ) + if hits: + parts: List[str] = [] + for h in hits[:2]: + c = await self.knowledge_storage.get(h["id"]) + if c and c.content: + parts.append(str(c.content)[:3000]) + for ev in (c.evidences or [])[:3]: + for s in (ev.snippets or [])[:2]: + parts.append(s[:500]) + if parts: + return "\n\n---\n\n".join(parts) + except Exception: + pass + + # Strategy C: 2nd rga file via tree navigation + if best_files and len(best_files) > 1: + fp2 = best_files[1]["path"] + tree_ev = await self._navigate_tree_for_evidence(fp2, query) + if tree_ev and len(tree_ev.strip()) > 50: + context.mark_file_read(fp2) + return tree_ev + + return None + + @staticmethod + def _parse_fast_json(text: str) -> Dict[str, Any]: + """Extract JSON from the FAST query analysis LLM response.""" + text = text.strip() + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError): + pass + cleaned = re.sub(r"^```(?:json)?\s*", "", text, flags=re.MULTILINE) + cleaned = re.sub(r"```\s*$", "", cleaned, flags=re.MULTILINE).strip() + try: + return json.loads(cleaned) + except (json.JSONDecodeError, TypeError): + pass + match = re.search(r"\{.*\}", text, re.DOTALL) + if match: + try: + return json.loads(match.group()) + except (json.JSONDecodeError, TypeError): + pass + return {} + + # ------------------------------------------------------------------ + # Phase 1 probes (each designed to run concurrently) + # ------------------------------------------------------------------ + + async def _probe_keywords( + self, query: str, + ) -> Tuple[Dict[str, float], List[str]]: + """Extract multi-level keywords from the query via LLM. + + Also extracts cross-lingual alternative keywords from the + ```` block and merges them into the result list. + + Additionally synthesises rga-friendly compound phrases from + Level 1 keywords so that downstream ``_retrieve_by_keywords`` + tries exact multi-word matches before falling back to atomic + terms (mirrors the strategy used by FAST mode). + + Returns: + Tuple of (keyword_idf_dict, keyword_list). + """ + await self._logger.info("[Probe:Keywords] Extracting keywords...") + dynamic_prompt = generate_keyword_extraction_prompt(num_levels=2) + keyword_prompt = dynamic_prompt.replace(KEYWORD_QUERY_PLACEHOLDER, query) + kw_response = await self.llm.achat( + messages=[{"role": "user", "content": keyword_prompt}], + stream=False, + ) + self.llm_usages.append(kw_response.usage) + + keyword_sets = self._extract_and_validate_multi_level_keywords( + kw_response.content, num_levels=2, + ) + + alt_keywords = self._extract_alt_keywords(kw_response.content) + if alt_keywords: + await self._logger.info(f"[Probe:Keywords] Cross-lingual alt: {list(alt_keywords.keys())}") + + for kw_set in keyword_sets: + if kw_set: + merged = {**kw_set, **alt_keywords} + # Synthesise rga-friendly compound phrases: promote + # multi-word Level-1 keywords to the front with boosted + # IDF so _retrieve_by_keywords tries them first as exact + # phrases (similar to FAST's primary/fallback strategy). + compound_phrases: Dict[str, float] = {} + atomic_terms: Dict[str, float] = {} + for kw, idf in merged.items(): + if " " in kw.strip() and len(kw.split()) >= 2: + compound_phrases[kw] = max(idf, 7.0) + else: + atomic_terms[kw] = idf + # Compounds first, then atomics — preserves ordering for + # _retrieve_by_keywords which iterates keywords in order. + ordered = {**compound_phrases, **atomic_terms} + kw_list = list(ordered.keys()) + await self._logger.info( + f"[Probe:Keywords] Extracted: {kw_list} " + f"(compounds={len(compound_phrases)})" + ) + return ordered, kw_list + + if alt_keywords: + return alt_keywords, list(alt_keywords.keys()) + + return {}, [] + + @staticmethod + def _has_directory_paths(paths: List[str]) -> bool: + """Return True if any element in *paths* is a directory.""" + return any(Path(p).is_dir() for p in paths) + + @staticmethod + def _resolve_file_hints( + paths: List[str], + file_hints: List[str], + max_depth: int = 8, + ) -> List[str]: + """Resolve file_hints (filenames) to absolute paths under *paths*. + + Lightweight name-only search: no metadata extraction. Used when the + user clearly asks for a specific document (e.g. "总结《foo.pdf》") + so we can skip full dir scan + LLM rank. + + Returns: + List of absolute path strings that match any hint (deduplicated, + order preserved). Empty if no matches. + """ + if not file_hints: + return [] + + hints = [h.strip() for h in file_hints if (h and isinstance(h, str))] + if not hints: + return [] + + def _name_matches(name: str, hint: str) -> bool: + name_n = name.strip() + hint_n = hint.strip() + if not hint_n: + return False + if name_n == hint_n: + return True + if hint_n.lower() in name_n.lower(): + return True + if Path(name_n).stem == Path(hint_n).stem: + return True + return False + + seen: set = set() + out: List[str] = [] + + def walk_dir(d: Path, depth: int) -> None: + if depth > max_depth or len(out) >= 20: + return + try: + for entry in sorted(d.iterdir(), key=lambda p: p.name): + if len(out) >= 20: + return + if entry.name.startswith("."): + continue + if entry.is_file(): + for hint in hints: + if _name_matches(entry.name, hint): + resolved = str(entry.resolve()) + if resolved not in seen: + seen.add(resolved) + out.append(resolved) + break + elif entry.is_dir(): + walk_dir(entry, depth + 1) + except PermissionError: + pass + + for p_str in paths: + p = Path(p_str).resolve() + if p.is_file(): + for hint in hints: + if _name_matches(p.name, hint): + resolved = str(p) + if resolved not in seen: + seen.add(resolved) + out.append(resolved) + break + elif p.is_dir(): + walk_dir(p, 0) + + return out + + async def _probe_dir_scan( + self, + paths: List[str], + enable: bool = True, + max_files: int = 500, + ): + """Scan directories for file metadata (filesystem only, no LLM). + + Automatically skips scanning when all *paths* are single files. + + Args: + paths: Normalised list of path strings to scan. + enable: Whether directory scanning is enabled. + max_files: Cap on number of files to scan (lower = faster). Returns: - ``(answer, cluster, context)`` — same triple as ``_search_deep`` - so the caller can handle both modes uniformly. + ScanResult or None if disabled / all paths are files. """ - context = SearchContext() - await self._logger.info(f"[FAST] Starting greedy search for: '{query[:80]}'") + if not enable or not self._has_directory_paths(paths): + return None - # ============================================================== - # Step 0: Cluster reuse — instant short-circuit (no LLM cost) - # When reuse succeeds we return here; no persistence step runs. - # ============================================================== - reused = await self._try_reuse_cluster(query, paths) - if reused is not None: - content = reused.content - if isinstance(content, list): - content = "\n".join(content) - await self._logger.success("[FAST] Reused cached knowledge cluster") - return str(content), reused, context + from sirchmunk.scan.dir_scanner import DirectoryScanner - # ============================================================== - # Step 1: LLM query analysis only (dir scan deferred until needed) - # ============================================================== - prompt = FAST_QUERY_ANALYSIS.format(user_input=query) - resp = await self.llm.achat( - messages=[{"role": "user", "content": prompt}], - stream=False, + if self._dir_scanner is None or self._dir_scanner.max_files != max_files: + self._dir_scanner = DirectoryScanner(llm=self.llm, max_files=max_files) + + await self._logger.info("[Probe:DirScan] Scanning directories...") + scan_result = await self._dir_scanner.scan(paths) + await self._logger.info( + f"[Probe:DirScan] Found {scan_result.total_files} files " + f"in {scan_result.total_dirs} dirs ({scan_result.scan_duration_ms:.0f}ms)" ) - self.llm_usages.append(resp.usage) - if resp.usage and isinstance(resp.usage, dict): - context.add_llm_tokens( - resp.usage.get("total_tokens", 0), usage=resp.usage, - ) + return scan_result - analysis = self._parse_fast_json(resp.content) - query_type = analysis.get("type", "search") - file_hints = analysis.get("file_hints", []) + async def _probe_knowledge_cache( + self, query: str, + ) -> KnowledgeProbeResult: + """Structured knowledge probe: embedding search with graph expansion. - if query_type == "chat": - chat_reply = analysis.get("response", "") - if chat_reply: - await self._logger.info("[FAST:Step1] LLM classified as chat intent") - return chat_reply, None, context - return (await self._respond_chat(query, context)) + Uses embedding similarity (threshold 0.50) when available, falling back + to SQL LIKE. Extracts file paths, topic keywords, and background + context from matched clusters and their graph neighbours. + """ + empty = KnowledgeProbeResult([], [], "") + try: + clusters: List[KnowledgeCluster] = [] - if query_type == "summary": - await self._logger.info("[FAST:Step1] Summary intent detected — delegating to doc analysis") - # When user names a specific file, resolve it and skip dir scan + rank - summary_paths: Optional[List[str]] = None - if file_hints: - summary_paths = self._resolve_file_hints(paths, file_hints) - if summary_paths: - await self._logger.info( - f"[FAST:Summary] Resolved file hint(s) → {[Path(p).name for p in summary_paths]}" + # Prefer embedding search for semantic quality + if self.embedding_client and self.embedding_client.is_ready(): + try: + qe = (await self.embedding_client.embed([query]))[0] + similar = await self.knowledge_storage.search_similar_clusters( + query_embedding=qe, top_k=5, similarity_threshold=0.50, ) - if summary_paths: - answer = await self._summarize_documents( - query, summary_paths, - top_k_files=len(summary_paths), - scan_result=None, + for m in (similar or []): + c = await self.knowledge_storage.get(m["id"]) + if c: + clusters.append(c) + except Exception: + pass + + # Fallback to SQL LIKE when embedding unavailable or empty + if not clusters: + clusters = await self.knowledge_storage.find(query, limit=3) + + if not clusters: + return empty + + seen_paths: set = set() + file_paths: List[str] = [] + extra_keywords: List[str] = [] + context_parts: List[str] = [] + seen_kw: set = set() + + def _collect_cluster(c: KnowledgeCluster) -> None: + for ev in getattr(c, "evidences", []): + fp = str(getattr(ev, "file_or_url", "")) + if fp and fp not in seen_paths and Path(fp).exists(): + seen_paths.add(fp) + file_paths.append(fp) + for p in getattr(c, "patterns", []) or []: + if p and p.lower() not in seen_kw: + seen_kw.add(p.lower()) + extra_keywords.append(p) + content = c.content + if isinstance(content, list): + content = "\n".join(content) + if content: + context_parts.append(str(content)[:500]) + + for c in clusters: + _collect_cluster(c) + + # One-hop graph expansion via WeakSemanticEdge + neighbour_ids: set = set() + for c in clusters: + for edge in getattr(c, "related_clusters", []): + tid = getattr(edge, "target_cluster_id", None) + if tid and tid not in neighbour_ids: + neighbour_ids.add(tid) + + for nid in list(neighbour_ids)[:6]: + try: + neighbour = await self.knowledge_storage.get(nid) + if neighbour: + _collect_cluster(neighbour) + except Exception: + pass + + if file_paths: + await self._logger.info( + f"[Probe:Knowledge] {len(file_paths)} files, " + f"{len(extra_keywords)} keywords from " + f"{len(clusters)} clusters + {len(neighbour_ids)} neighbours" ) - if answer: - return answer, self._make_answer_cluster(query, answer, "FS", file_paths=summary_paths), context - # No hint or resolve failed: run dir scan (if enabled) then rank + summarize - scan_result = await self._probe_dir_scan(paths, enable=enable_dir_scan, - max_files=300) if enable_dir_scan else None - answer = await self._summarize_documents( - query, paths, - top_k_files=top_k_files, - scan_result=scan_result, + + return KnowledgeProbeResult( + file_paths=file_paths, + extra_keywords=extra_keywords[:15], + background_context="\n\n".join(context_parts[:3]), ) - if answer: - return answer, self._make_answer_cluster(query, answer, "FS", file_paths=paths), context - await self._logger.info("[FAST:Step1] Summary fallback — no documents, continuing search") + except Exception: + return empty - primary = analysis.get("primary", [])[:2] - fallback = analysis.get("fallback", [])[:3] - primary_alt = analysis.get("primary_alt", [])[:2] - fallback_alt = analysis.get("fallback_alt", [])[:3] + def _load_cached_trees(self) -> list: + """Load DocumentTree objects from the tree cache directory. - if primary_alt: - primary = primary + primary_alt - if fallback_alt: - fallback = fallback + fallback_alt + Returns a list of ``DocumentTree`` instances whose file paths exist + on disk. Returns an empty list when the tree cache is absent or + contains no valid entries. + """ + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if not tree_cache.exists(): + return [] + try: + from sirchmunk.learnings.tree_indexer import DocumentTree - # --- IDF weights from LLM --- - keyword_idfs: Dict[str, float] = analysis.get("idf", {}) - if not keyword_idfs: - all_kws = (primary or []) + (fallback or []) - keyword_idfs = {kw: max(0.5, min(1.0, len(kw) / 5.0)) for kw in all_kws} + trees = [] + for tree_file in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: + try: + t = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8") + ) + if t.root and t.file_path and Path(t.file_path).exists(): + trees.append(t) + except Exception: + continue + return trees + except Exception: + return [] - if not primary and not fallback: - await self._logger.warning("[FAST] No keywords extracted") - msg = f"Could not extract search terms from query: '{query}'" - return msg, None, context + @staticmethod + def _prefilter_trees_by_query( + query: str, trees: list, max_candidates: int, min_score: float, + ) -> list: + """Rule-based pre-filter: score trees by query-token overlap with filenames. + + Extracts meaningful tokens from the query (alphanumeric words, 4-digit + years, multi-word entity fragments) and scores each tree's filename by + weighted token overlap. Returns the top-scoring candidates, or the + full list if fewer than *max_candidates* pass the threshold. + + This avoids sending hundreds of root summaries to the LLM. + """ + raw_tokens = re.findall(r"[A-Za-z0-9]+", query.lower()) + tokens = [t for t in raw_tokens if len(t) >= 2 and t not in _STOP_WORDS] + if not tokens: + return trees + + # Extract years: bare "2018" and compound prefixed forms "fy2018", "cy2023" + year_tokens: Set[str] = set() + for t in tokens: + if re.fullmatch(r"(?:19|20)\d{2}", t): + year_tokens.add(t) + else: + m = re.search(r"((?:19|20)\d{2})", t) + if m: + year_tokens.add(m.group(1)) + entity_tokens = {t for t in tokens if len(t) >= 2 and t not in year_tokens} + + scored: List[Tuple[float, int]] = [] + for idx, tree in enumerate(trees): + name_lower = Path(tree.file_path).stem.lower() + name_parts = set(re.findall(r"[a-z0-9]+", name_lower)) + + score = 0.0 + for tok in entity_tokens: + if tok in name_lower: + score += 2.0 + elif any(tok[:4] in part for part in name_parts if len(tok) >= 4): + score += 0.5 + for yr in year_tokens: + if yr in name_lower: + score += 3.0 + + scored.append((score, idx)) + + scored.sort(key=lambda x: -x[0]) + + candidates = [trees[idx] for sc, idx in scored if sc >= min_score] + if not candidates: + return [trees[idx] for _, idx in scored[:max_candidates]] + return candidates[:max_candidates] + + async def _llm_select_from_trees( + self, query: str, trees: list, max_select: int, + ) -> List[str]: + """Two-stage LLM-driven file selection from tree root summaries. + + Stage 1 (rule-based): when the pool exceeds ``_TREE_PREFILTER_THRESHOLD``, + narrow candidates by query-token / filename overlap. + Stage 2 (LLM): present root summaries of the narrowed set for precise selection. + + When the number of trees is at most *max_select*, returns all paths + without an LLM call. + """ + if not trees: + return [] + if len(trees) <= max_select: + return [t.file_path for t in trees] + + pool = trees + if len(pool) > self._TREE_PREFILTER_THRESHOLD: + pool = self._prefilter_trees_by_query( + query, pool, + max_candidates=self._TREE_PREFILTER_MAX_CANDIDATES, + min_score=self._TREE_PREFILTER_MIN_SCORE, + ) + if len(pool) <= max_select: + return [t.file_path for t in pool] + + listing = "\n".join( + f"[{i}] {Path(t.file_path).name}: " + f"{(t.root.summary or '')[:self._CATALOG_SUMMARY_TRUNCATE]}" + for i, t in enumerate(pool) + ) + prompt = ( + f'Given the query: "{query}"\n\n' + f"Select the 1-{max_select} most relevant documents " + f"(by index number):\n{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" + ) + resp = await self.llm.achat([{"role": "user", "content": prompt}]) + self.llm_usages.append(resp.usage) + + selected_indices: List[int] = [] + try: + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + selected_indices = [ + idx for idx in json.loads(m.group()) + if isinstance(idx, int) and 0 <= idx < len(pool) + ] + except (json.JSONDecodeError, TypeError): + pass + + if not selected_indices: + selected_indices = list(range(min(max_select, len(pool)))) + + return [ + pool[idx].file_path + for idx in selected_indices[:max_select] + if Path(pool[idx].file_path).exists() + ] + + async def _probe_tree_index(self, query: str) -> List[str]: + """LLM-driven file discovery via compiled tree root summaries (PageIndex). + + Loads all cached document trees, presents their root summaries to the + LLM, and asks it to select the most relevant documents. Returns file + paths of the most relevant documents. + """ + try: + trees = self._load_cached_trees() + if not trees: + return [] + result = await self._llm_select_from_trees( + query, trees, max_select=self._DEEP_TREE_PROBE_MAX_FILES, + ) + if result: + await self._logger.info( + f"[Probe:TreeIndex] LLM selected {len(result)} documents " + f"from {len(trees)} tree indices" + ) + return result + except Exception: + return [] + + async def _probe_compile_hints( + self, + keywords: List[str], + *, + scope: Optional["_PathScope"] = None, + ) -> CompileHints: + """Zero-LLM enrichment from compile manifest and tree cache. + + Scans the compile manifest for clusters whose patterns overlap with + the query keywords, and scans cached tree root summaries for keyword + matches. No LLM calls — only local JSON reads and in-memory DB lookups. + + When *scope* is provided, only file paths falling within the scope + are included in the returned hints. + """ + empty = CompileHints([], []) + if not keywords: + return empty + + kw_lower = {k.lower() for k in keywords} + file_paths: List[str] = [] + extra_keywords: List[str] = [] + seen_paths: set = set() + seen_kw: set = set(kw_lower) + + def _accept(fp: str) -> bool: + return bool(fp) and fp not in seen_paths and Path(fp).exists() and ( + scope is None or scope.contains(fp) + ) + + # --- Cluster pattern matching via manifest --- + manifest_path = self.work_path / ".cache" / "compile" / "manifest.json" + if manifest_path.exists(): + try: + from sirchmunk.learnings.compiler import CompileManifest + manifest = CompileManifest.from_json( + manifest_path.read_text(encoding="utf-8") + ) + cluster_ids: set = set() + for entry in manifest.files.values(): + cluster_ids.update(entry.cluster_ids) + + for cid in list(cluster_ids)[:50]: + try: + c = await self.knowledge_storage.get(cid) + except Exception: + continue + if not c: + continue + cluster_patterns = [ + p.lower() for p in (getattr(c, "patterns", []) or []) if p + ] + if kw_lower & set(cluster_patterns): + for ev in getattr(c, "evidences", []): + fp = str(getattr(ev, "file_or_url", "")) + if _accept(fp): + seen_paths.add(fp) + file_paths.append(fp) + for p in cluster_patterns: + if p not in seen_kw: + seen_kw.add(p) + extra_keywords.append(p) + except Exception: + pass + + # --- Tree root summary scanning (keyword substring match) --- + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if tree_cache.exists(): + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + for tree_file in sorted(tree_cache.glob("*.json"))[:100]: + try: + tree = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8") + ) + except Exception: + continue + if not tree.root or not tree.file_path: + continue + summary_lower = (tree.root.summary or "").lower() + if any(kw in summary_lower for kw in kw_lower): + fp = tree.file_path + if _accept(fp): + seen_paths.add(fp) + file_paths.append(fp) + except Exception: + pass - await self._logger.info( - f"[FAST:Step1] Primary: {primary}, Fallback: {fallback}" + return CompileHints( + file_paths=file_paths[:15], + extra_keywords=extra_keywords[:10], ) - # ============================================================== - # Step 2: rga cascade — primary first, fallback only if needed - # Dir scan runs only when enabled, for fallback when rga misses. - # ============================================================== - context.add_search(query) - include_patterns = list(include or []) - for hint in file_hints: - if "*" in hint or "." in hint: - include_patterns.append(hint) + @staticmethod + def _merge_compile_hints(base: "CompileHints", supplement: "CompileHints") -> "CompileHints": + """Merge two CompileHints, deduplicating file paths and keywords.""" + seen_fps = set(base.file_paths) + merged_fps = list(base.file_paths) + for fp in supplement.file_paths: + if fp not in seen_fps: + seen_fps.add(fp) + merged_fps.append(fp) + seen_kws = set(base.extra_keywords) + merged_kws = list(base.extra_keywords) + for kw in supplement.extra_keywords: + if kw not in seen_kws: + seen_kws.add(kw) + merged_kws.append(kw) + return CompileHints(file_paths=merged_fps[:15], extra_keywords=merged_kws[:10]) + + async def _probe_summary_index( + self, + query: str, + artifacts: Optional["CompileArtifacts"] = None, + *, + scope: Optional["_PathScope"] = None, + ) -> List[str]: + """Zero-LLM file discovery via compile-time summary index (BM25 only). - rga_kwargs = dict( - paths=paths, max_depth=max_depth, - include=include_patterns or None, exclude=exclude, - ) + Uses the pre-built summary index's BM25 channel to find files whose + summaries are lexically similar to the query. No LLM or embedding + calls — pure local computation. - best_files: Optional[List[Dict[str, Any]]] = None - used_level = "primary" - evidence = "" + When *scope* is provided, results are post-filtered to only include + file paths within the search scope. - if primary: - best_files = await self._fast_find_best_file( - primary, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs - ) + Args: + query: User query string. + artifacts: Compile artifacts (uses summary_index field). + scope: Optional path scope for filtering results. - if not best_files and fallback: - used_level = "fallback" - await self._logger.info( - "[FAST:Step2] Primary miss, trying fine-grained fallback" - ) - best_files = await self._fast_find_best_file( - fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs + Returns: + File paths of top-k matching documents, or empty list. + """ + if artifacts is None or artifacts.summary_index is None: + return [] + + try: + from sirchmunk.utils.tokenizer_util import TokenizerUtil + _tokenizer = TokenizerUtil() + query_tokens = _tokenizer.segment(query) + + if not query_tokens: + return [] + + # BM25-only search: pass query_embedding=None to skip embedding channel + results = artifacts.summary_index.search( + query_embedding=None, + query_tokens=query_tokens, + top_k=self._SUMMARY_INDEX_TOP_K, ) - # --- Fallback: use dir_scan only when rga misses and dir scan is enabled --- - if not best_files and enable_dir_scan: - scan_result = await self._probe_dir_scan(paths, enable=True, max_files=300) - if scan_result is not None: - await self._logger.info("[FAST:Step2] rga miss — falling back to dir_scan ranking") - ranked_paths = await self._rank_dir_scan_candidates( - query, scan_result, top_k=10, include_medium=True, - ) - if ranked_paths: - used_level = "dir_scan" - best_files = [{"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} for p in ranked_paths[:top_k_files]] + file_paths = [ + fp for fp, score in results + if score > 0.0 and Path(fp).exists() + and (scope is None or scope.contains(fp)) + ] - if not best_files: - if llm_fallback: + if file_paths: await self._logger.info( - "[FAST:Step2] No files found, llm_fallback=True \u2192 skip to LLM summary" - ) - evidence = self._LLM_FALLBACK_EVIDENCE - else: - await self._logger.warning( - f"[FAST:Step2] No matching files found in paths: {paths}. " + f"[SummaryIndex:BM25] Found {len(file_paths)} files " + f"from {artifacts.summary_index.num_entries} indexed docs" ) - return _NO_RESULTS_MESSAGE, None, context + return file_paths + except Exception as exc: + await self._logger.warning(f"[SummaryIndex:BM25] Probe failed: {exc}") + return [] - if best_files: - file_path = best_files[0]["path"] - match_objects = best_files[0].get("matches", []) - await self._logger.info( - f"[FAST:Step2] Best file ({used_level}): {Path(file_path).name} " - f"({best_files[0].get('total_matches', 0)} hits, score={best_files[0].get('weighted_score', 0):.2f})" - ) + async def _probe_catalog_for_deep( + self, + query: str, + artifacts: Optional["CompileArtifacts"] = None, + ) -> List[str]: + """Zero-LLM file discovery via document catalog keyword overlap. - # ============================================================== - # Step 3: Context sampling around grep hits (no LLM) - # Multi-file evidence aggregation - # ============================================================== - evidence_parts = [] - total_evidence_chars = 0 - for bf in best_files: - if total_evidence_chars >= self._FAST_MAX_EVIDENCE_CHARS: - break + Scores each catalog entry by counting query token overlap with the + document summary. Returns top-k file paths sorted by overlap score. - file_path = bf["path"] - fname = Path(file_path).name - ext = Path(file_path).suffix.lower() + Args: + query: User query string. + artifacts: Compile artifacts (uses catalog field). - # Small file short-circuit: read full content instead of grep sampling - ev = None - if ext in self._FAST_TEXT_EXTENSIONS: - try: - file_size = Path(file_path).stat().st_size - if file_size < self._FAST_SMALL_FILE_THRESHOLD: - full_text = Path(file_path).read_text(errors="replace") - if len(full_text) < self._FAST_SMALL_FILE_THRESHOLD: - ev = f"[{fname}]\n{full_text}" - await self._logger.info( - f"[FAST] Small file short-circuit: reading full content of {fname} " - f"({len(full_text)} chars)" - ) - except Exception: - pass # Fall through to normal evidence extraction + Returns: + File paths of top-k matching documents, or empty list. + """ + if not artifacts or not artifacts.catalog: + return [] - # Normal path: grep-based evidence sampling - if ev is None: - ev = await self._fast_sample_evidence(file_path, bf.get("matches", [])) + try: + query_tokens = self._tokenize_for_matching(query.lower()) + if not query_tokens: + return [] - if ev: - remaining = self._FAST_MAX_EVIDENCE_CHARS - total_evidence_chars - chunk = ev[:remaining] - evidence_parts.append(chunk) - total_evidence_chars += len(chunk) - context.mark_file_read(file_path) + scored: List[Tuple[str, float]] = [] + for entry in artifacts.catalog: + fp = entry.get("path", "") + if not fp or not Path(fp).exists(): + continue + summary = (entry.get("summary", "") or "").lower() + name = (entry.get("name", "") or "").lower() + doc_tokens = self._tokenize_for_matching(f"{name} {summary}") + overlap = len(query_tokens & doc_tokens) + if overlap > 0: + # Normalize by query length to avoid bias toward long summaries + score = overlap / max(1, len(query_tokens)) + scored.append((fp, score)) + + if not scored: + return [] - evidence = "\n\n---\n\n".join(evidence_parts) + scored.sort(key=lambda x: x[1], reverse=True) + result_paths = [fp for fp, _ in scored[:self._DEEP_CATALOG_TOP_K]] - if not evidence or len(evidence.strip()) < 20: - if llm_fallback: - await self._logger.info( - "[FAST:Step3] No usable evidence, llm_fallback=True \u2192 LLM summary" - ) - evidence = self._LLM_FALLBACK_EVIDENCE - else: - await self._logger.warning("[FAST:Step3] No usable evidence extracted") - return _NO_RESULTS_MESSAGE, None, context + if result_paths: + await self._logger.info( + f"[DEEP:CatalogProbe] Found {len(result_paths)} files " + f"from {len(artifacts.catalog)} catalog entries" + ) + return result_paths + except Exception as exc: + await self._logger.warning(f"[DEEP:CatalogProbe] Failed: {exc}") + return [] - await self._logger.info( - f"[FAST:Step3] Evidence: {len(evidence)} chars from {Path(file_path).name}" + async def _probe_tree_for_fast( + self, query: str, artifacts: Optional["CompileArtifacts"] = None, + ) -> List[str]: + """Active tree-based file discovery for FAST mode (1 LLM call). + + When compiled tree indices are available and cover more than 2 files, + asks the LLM to select the most relevant 1-2 documents from root + summaries. Delegates to the shared ``_llm_select_from_trees`` helper. + + Returns file paths of selected documents, or empty list when trees + are unavailable or cover too few files to justify an LLM call. + """ + print(f"SEARCH_WIKI_DEBUG [D4] _probe_tree_for_fast: tree_available_paths={len(artifacts.tree_available_paths) if artifacts else 0}", flush=True) + if not artifacts or not artifacts.tree_available_paths: + return [] + + try: + trees = self._load_cached_trees() + # Scope-filter: only keep trees whose files are in artifacts + if artifacts and artifacts.tree_available_paths: + scoped = artifacts.tree_available_paths + trees = [t for t in trees if t.file_path in scoped] + print(f"SEARCH_WIKI_DEBUG [D5] loaded_trees: {len(trees)} trees, paths={[t.file_path for t in trees][:3]}", flush=True) + if not trees: + return [] + result = await self._llm_select_from_trees( + query, trees, max_select=self._FAST_TREE_PROBE_MAX_FILES, ) + print(f"SEARCH_WIKI_DEBUG [D6] llm_select_result: {result}", flush=True) + if result: + await self._logger.info( + f"[FAST:TreeProbe] Selected {len(result)} files " + f"from {len(trees)} tree indices" + ) + return result + except Exception as exc: + await self._logger.warning(f"[FAST:TreeProbe] Failed: {exc}") + return [] - keywords_used = primary if used_level == "primary" else fallback + @staticmethod + async def _async_noop(default=None): + """No-op coroutine used as placeholder in gather().""" + return default - # ============================================================== - # Step 4: LLM answer from focused evidence (single call) - # ============================================================== - answer_prompt = ROI_RESULT_SUMMARY.format( - user_input=query, - text_content=evidence, - ) - answer_resp = await self.llm.achat( - messages=[{"role": "user", "content": answer_prompt}], - stream=True, + # ------------------------------------------------------------------ + # Phase 2 retrievers + # ------------------------------------------------------------------ + + async def _retrieve_by_keywords( + self, + keywords: List[str], + paths: List[str], + max_depth: Optional[int] = 5, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ) -> List[str]: + """Run keyword search via rga and return discovered file paths. + + Each keyword is searched concurrently (literal per-term strategy). + """ + from sirchmunk.agentic.tools import KeywordSearchTool + + tool = KeywordSearchTool( + retriever=self.grep_retriever, + paths=paths, + max_depth=max_depth if max_depth is not None else 5, + max_results=20, + include=include, + exclude=exclude, ) - self.llm_usages.append(answer_resp.usage) - if answer_resp.usage and isinstance(answer_resp.usage, dict): - context.add_llm_tokens( - answer_resp.usage.get("total_tokens", 0), usage=answer_resp.usage, - ) + ctx = SearchContext() # lightweight context for this probe + result_text, meta = await tool.execute(context=ctx, keywords=keywords) - answer, should_save, should_answer = self._parse_summary_response( - answer_resp.content or "" + # Extract discovered file paths from the tool's context logs + discovered: List[str] = [] + for log_entry in ctx.retrieval_logs: + discovered.extend(log_entry.metadata.get("files_discovered", [])) + + await self._logger.info( + f"[Retrieve:Keywords] {len(discovered)} files from rga search" ) - if not should_answer: - if llm_fallback: - await self._logger.info( - "[FAST:Step4] Summary gate rejected evidence, llm_fallback=True → LLM fallback" - ) - answer, should_save = await self._summarise_fast_fallback(query, context) - else: - await self._logger.warning( - "[FAST:Step4] Summary gate rejected evidence and llm_fallback=False " - "→ returning no results" - ) - return _NO_RESULTS_MESSAGE, None, context - if not should_save: - await self._logger.info("[FAST] Quality gate: low-quality answer, skipping cluster save") - await self._logger.success("[FAST] Search complete (2 LLM calls, no persist)") - return answer, None, context + return discovered - cluster = self._build_fast_cluster( - query, answer, file_path, evidence, keywords_used, + async def _rank_dir_scan_candidates( + self, + query: str, + scan_result, + *, + top_k: int = 20, + include_medium: bool = False, + ) -> List[str]: + """Run LLM ranking on dir_scan candidates and return relevant paths. + + Args: + include_medium: When True, include both high and medium relevance. + """ + if self._dir_scanner is None: + return [] + + ranked = await self._dir_scanner.rank(query, scan_result, top_k=top_k) + accept = {"high", "medium"} if include_medium else {"high"} + paths = [ + c.path for c in ranked.ranked_candidates + if c.relevance in accept + ] + await self._logger.info( + f"[Retrieve:DirScan] {len(paths)} relevant files " + f"(accept={accept})" ) - self._add_query_to_cluster(cluster, query) - try: - await self._save_cluster_with_embedding(cluster) - except Exception as exc: - _loguru_logger.warning( - f"[FAST] Failed to save cluster with embedding: {exc}" - ) + return paths - await self._logger.success("[FAST] Search complete (2 LLM calls)") - return answer, cluster, context + async def _scan_and_rank_paths( + self, + query: str, + paths: List[str], + *, + max_files: int = 300, + top_k: int = 20, + include_medium: bool = True, + ) -> List[str]: + """Scan directories and return LLM-ranked relevant file paths. - # ---- FAST helpers ---- + Combines :meth:`_probe_dir_scan` (filesystem walk) and + :meth:`_rank_dir_scan_candidates` (LLM ranking) in one call. + Automatically skips scanning when all *paths* are single files. - @staticmethod - def _count_keyword_tf_per_file(raw_results: List[Dict[str, Any]]) -> Dict[str, int]: - """Count matches per file from rga JSON output.""" - counts: Dict[str, int] = {} - current_path: Optional[str] = None - for item in raw_results: - item_type = item.get("type") - if item_type == "begin": - current_path = item.get("data", {}).get("path", {}).get("text") - elif item_type == "match" and current_path is not None: - counts[current_path] = counts.get(current_path, 0) + 1 - elif item_type == "end": - current_path = None - return counts + Returns: + Ranked file paths (high + optionally medium relevance), + or empty list when scanning is not applicable. + """ + scan_result = await self._probe_dir_scan( + paths, enable=True, max_files=max_files, + ) + if scan_result is None: + return [] - @staticmethod - def _dedup_merged_files( - merged: List[Dict[str, Any]], - per_file_kw_tf: Dict[str, Dict[str, int]], - match_limit: int = 20, - ) -> List[Dict[str, Any]]: - """Deduplicate merged file entries by path, combining matches from - multiple keyword searches into a single entry per file. + return await self._rank_dir_scan_candidates( + query, scan_result, + top_k=top_k, include_medium=include_medium, + ) - When the same file appears in multiple rga begin/end groups (one per - keyword search), this merges them so downstream scoring and evidence - extraction operate on a single, complete representation. + # ------------------------------------------------------------------ + # Phase 3: Merge + cluster build + # ------------------------------------------------------------------ - Args: - merged: File entries from GrepRetriever.merge_results(), may - contain duplicates. - per_file_kw_tf: Pre-computed per-file keyword TF counts (not - modified, used only for reference). - match_limit: Maximum matches to keep per file after merging. + @staticmethod + def _merge_file_paths( + keyword_files: List[str], + dir_scan_files: List[str], + knowledge_hits: List[str], + ) -> List[str]: + """Merge file paths from all retrieval paths, dedup, preserve priority. - Returns: - Deduplicated list with one entry per unique file path. + Priority: keyword_search > knowledge_cache > dir_scan. """ - if not merged: - return merged + seen: set = set() + merged: List[str] = [] - seen: Dict[str, int] = {} # path -> index in deduped - deduped: List[Dict[str, Any]] = [] + for fp in keyword_files + knowledge_hits + dir_scan_files: + if fp and fp not in seen: + seen.add(fp) + merged.append(fp) - for entry in merged: - fpath = entry["path"] - if fpath in seen: - # Merge into existing entry - idx = seen[fpath] - existing = deduped[idx] - existing["matches"].extend(entry.get("matches", [])) - existing["lines"].extend(entry.get("lines", [])) - existing["total_matches"] += entry.get("total_matches", 0) - else: - # New file — clone to avoid mutating original - seen[fpath] = len(deduped) - deduped.append({ - "path": fpath, - "matches": list(entry.get("matches", [])), - "lines": list(entry.get("lines", [])), - "total_matches": entry.get("total_matches", 0), - "total_score": entry.get("total_score", 0.0), - }) + return merged - # Trim matches to limit per file - for entry in deduped: - if len(entry["matches"]) > match_limit: - # Sort by score descending, keep top - entry["matches"].sort( - key=lambda x: x.get("score", 0.0), reverse=True - ) - entry["matches"] = entry["matches"][:match_limit] + def _get_tree_indexer(self): + """Lazily construct a DocumentTreeIndexer for search-time tree navigation.""" + from sirchmunk.learnings.tree_indexer import DocumentTreeIndexer - return deduped + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if not tree_cache.exists(): + return None + _cb = getattr(self._logger, 'log_callback', None) + return DocumentTreeIndexer( + llm=self.llm, + cache_dir=tree_cache, + log_callback=_cb, + ) - @staticmethod - def _prune_by_score( - candidates: List[Dict[str, Any]], - top_k: int = 3, - relative_ratio: float = 0.30, - gap_ratio: float = 0.50, - min_count: int = 1, - ) -> List[Dict[str, Any]]: - """Dynamically prune ranked file candidates by score distribution. + async def _build_cluster( + self, + query: str, + file_paths: List[str], + query_keywords: Dict[str, float], + top_k_files: int = 5, + top_k_snippets: int = 5, + ) -> Optional[KnowledgeCluster]: + """Build a KnowledgeCluster via knowledge_base.build(). - Applies a three-stage filter to remove clearly irrelevant files: + Constructs the Request wrapper and delegates to the knowledge + base for parallel Monte Carlo evidence sampling. When compiled + tree indices exist, passes a ``tree_indexer`` so that evidence + extraction can navigate to relevant sections before sampling. + """ + try: + request = Request( + messages=[ + Message( + role="user", + content=[ContentItem(type="text", text=query)], + ), + ], + ) + retrieved_infos = [{"path": fp} for fp in file_paths] - 1. **Relative threshold**: Discard files scoring below - ``max_score * relative_ratio`` (default 30%). - 2. **Gap detection**: Scan adjacently ranked files; when the score - drop from one to the next exceeds ``prev_score * gap_ratio`` - (default 50%), truncate the list at that point. - 3. **Minimum guarantee**: Ensure at least ``min_count`` files - survive (default 1). + cluster = await self.knowledge_base.build( + request=request, + retrieved_infos=retrieved_infos, + keywords=query_keywords, + top_k_files=top_k_files, + top_k_snippets=top_k_snippets, + verbose=self.verbose, + tree_indexer=self._get_tree_indexer(), + ) + self.llm_usages.extend(self.knowledge_base.llm_usages) + self.knowledge_base.llm_usages.clear() - Finally the result is capped at ``top_k``. + if cluster: + await self._logger.success( + f"[Phase 3] KnowledgeCluster built: {cluster.name} " + f"({len(cluster.evidences)} evidence units)" + ) + return cluster + except Exception as exc: + await self._logger.warning(f"[Phase 3] knowledge_base.build() failed: {exc}") + return None - Args: - candidates: File dicts sorted by ``weighted_score`` descending. - top_k: Maximum number of files to return. - relative_ratio: Fraction of the top score used as a floor. - gap_ratio: Maximum tolerated relative drop between adjacent - candidates. - min_count: Minimum number of candidates to keep regardless of - score. + async def _gather_graph_context(self, cluster: KnowledgeCluster) -> str: + """Enrich answer context with knowledge from graph neighbours. - Returns: - Pruned list of candidates (length in [min_count, top_k]). + Traverses the cluster's ``related_clusters`` edges (sorted by weight), + fetches the top neighbours, and returns a joined summary string that + can be appended to the cluster content before answer generation. """ - if not candidates: - return [] + edges = sorted( + getattr(cluster, "related_clusters", []) or [], + key=lambda e: getattr(e, "weight", 0), + reverse=True, + ) + if not edges: + return "" - max_score = candidates[0].get("weighted_score", 0.0) + parts: List[str] = [] + for edge in edges[:3]: + tid = getattr(edge, "target_cluster_id", None) + if not tid: + continue + try: + neighbour = await self.knowledge_storage.get(tid) + except Exception: + continue + if not neighbour: + continue + content = neighbour.content + if isinstance(content, list): + content = "\n".join(content) + name = getattr(neighbour, "name", "") or "" + snippet = str(content or "")[:300] + if snippet: + parts.append(f"- {name}: {snippet}") - # Step 1: Relative threshold filter - threshold = max_score * relative_ratio - filtered = [f for f in candidates if f.get("weighted_score", 0.0) >= threshold] - if not filtered: - filtered = candidates[:min_count] + if not parts: + return "" + await self._logger.info( + f"[Phase 3.5] Graph context: {len(parts)} neighbour summaries" + ) + return "Related knowledge:\n" + "\n".join(parts) - # Step 2: Gap detection truncation - result = [filtered[0]] - for i in range(1, len(filtered)): - prev_score = filtered[i - 1].get("weighted_score", 0.0) - curr_score = filtered[i].get("weighted_score", 0.0) - if prev_score > 0 and (prev_score - curr_score) > prev_score * gap_ratio: - break - result.append(filtered[i]) + # ------------------------------------------------------------------ + # Phase 4: Answer generation + # ------------------------------------------------------------------ - # Step 3: Minimum guarantee - if len(result) < min_count and len(filtered) >= min_count: - result = filtered[:min_count] + _INTENT_PROMPT_MAP = { + "lookup": ROI_LOOKUP_SYNTHESIS, + "computation": ROI_COMPUTATION_SYNTHESIS, + "comparison": ROI_COMPARISON_SYNTHESIS, + } - # Cap at top_k - return result[:top_k] + @classmethod + def _select_synthesis_prompt( + cls, + query: str, + evidence: str, + intent: str = "", + *, + document_context: Optional[str] = None, + ) -> str: + """Select and format the synthesis prompt based on query intent. - async def _fast_find_best_file( - self, - keywords: List[str], - paths: List[str], - max_depth: Optional[int] = 5, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - top_k: int = 1, - keyword_idfs: Optional[Dict[str, float]] = None, - ) -> Optional[List[Dict[str, Any]]]: - """Search per keyword via rga and return the top-k best-matching files - ranked by IDF-weighted log-TF scoring. + Falls back to ``ROI_RESULT_SUMMARY`` for unknown intents or when + the caller passes no intent (FAST mode compatibility). + """ + template = cls._INTENT_PROMPT_MAP.get(intent, ROI_RESULT_SUMMARY) + + prompt = template.format(user_input=query, text_content=evidence) + + if document_context: + prompt = ( + f"{prompt}\n\n### Document Context\n{document_context}" + ) + return prompt + + async def _summarise_cluster( + self, query: str, cluster: KnowledgeCluster, + intent: str = "", + ) -> Tuple[str, bool, bool]: + """Generate a final answer summary from a KnowledgeCluster. + + When *intent* is provided, selects a specialised synthesis prompt + (lookup / computation / comparison). Falls back to the general + ``ROI_RESULT_SUMMARY`` for FAST mode or unknown intents. Returns: - List of merged file dicts (path, matches, lines, total_matches, weighted_score) or None. + ``(summary_text, should_save, should_answer)`` where: + - should_save: quality verdict for persistence + - should_answer: evidence sufficiency verdict for answering """ - all_raw: List[Dict[str, Any]] = [] - per_file_kw_tf: Dict[str, Dict[str, int]] = {} # {file_path: {keyword: count}} + sep = "\n" + cluster_text_content = ( + f"{cluster.name}\n\n" + f"{sep.join(cluster.description)}\n\n" + f"{cluster.content if isinstance(cluster.content, str) else sep.join(cluster.content)}" + ) - for kw in keywords: - try: - results = await self.grep_retriever.retrieve( - terms=kw, path=paths, literal=True, regex=False, - max_depth=max_depth, include=include, exclude=exclude, - timeout=30.0, - ) - if results: - all_raw.extend(results) - # Track per-file TF for this keyword - kw_counts = self._count_keyword_tf_per_file(results) - for fpath, count in kw_counts.items(): - per_file_kw_tf.setdefault(fpath, {})[kw] = count - except Exception as exc: - await self._logger.warning( - f"[FAST] rga literal search failed for '{kw}': {exc}" - ) + result_sum_prompt = self._select_synthesis_prompt( + query, cluster_text_content, intent, + ) - # Fallback: escaped-regex OR (handles adapters that only work in regex mode) - if not all_raw and keywords: - try: - escaped = [re.escape(kw) for kw in keywords] - pattern = "|".join(escaped) - results = await self.grep_retriever.retrieve( - terms=pattern, path=paths, literal=False, regex=True, - max_depth=max_depth, include=include, exclude=exclude, - timeout=30.0, - ) - if results: - all_raw.extend(results) - # For regex OR fallback, attribute matches to individual keywords - # by checking which keywords appear in each match line - # (simplified: count total matches per file, distribute proportionally) - regex_counts = self._count_keyword_tf_per_file(results) - for fpath, count in regex_counts.items(): - # Attribute to all keywords equally (approximation for OR regex) - per_kw_share = max(1, count // len(keywords)) if keywords else count - for kw in keywords: - existing = per_file_kw_tf.get(fpath, {}).get(kw, 0) - if existing == 0: # Only fill if not already set by literal search - per_file_kw_tf.setdefault(fpath, {})[kw] = per_kw_share - except Exception as exc: - await self._logger.warning( - f"[FAST] rga regex search failed: {exc}" - ) + await self._logger.info("[Phase 4] Generating search result summary...") + response = await self.llm.achat( + messages=[{"role": "user", "content": result_sum_prompt}], + stream=True, + ) + self.llm_usages.append(response.usage) - # Fallback: filename search - if not all_raw: - try: - fn_results = await self.grep_retriever.retrieve_by_filename( - patterns=[f".*{re.escape(kw)}.*" for kw in keywords], - path=paths, case_sensitive=False, max_depth=max_depth, - timeout=30.0, - ) - if fn_results: - return [{"path": fn_results[0]["path"], "matches": [], "lines": [], "total_matches": 0, "weighted_score": 0.0}] - except Exception as exc: - await self._logger.warning( - f"[FAST] filename search failed: {exc}" - ) - return None + summary, should_save, should_answer = self._parse_summary_response(response.content) + return summary, should_save, should_answer - merged = GrepRetriever.merge_results(all_raw, limit=20) - if not merged: - return None + async def _summarise_cluster_fallback(self, query: str) -> Tuple[str, bool]: + """Generate an answer using the ROI summary prompt with fallback evidence. - # Deduplicate file entries from multi-keyword searches - merged = self._dedup_merged_files(merged, per_file_kw_tf) + Feeds the standard fallback text so the LLM answers from its own + knowledge without adding an extra LLM call to the pipeline. + """ + result_sum_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, + text_content=self._LLM_FALLBACK_EVIDENCE, + ) + await self._logger.info("[Phase 4] Generating fallback summary from LLM knowledge...") + response = await self.llm.achat( + messages=[{"role": "user", "content": result_sum_prompt}], + stream=True, + ) + self.llm_usages.append(response.usage) + summary, _, _ = self._parse_summary_response(response.content) + return summary, False # Never save fallback answers - # --- IDF × (1 + log TF) weighted scoring --- - _idfs = keyword_idfs or {} - for f in merged: - fpath = f["path"] - kw_tf = per_file_kw_tf.get(fpath, {}) - score = 0.0 - for kw in keywords: - tf = kw_tf.get(kw, 0) - if tf > 0: - idf = _idfs.get(kw, max(0.5, min(1.0, len(kw) / 5.0))) - score += idf * (1.0 + math.log(tf)) - f["weighted_score"] = score + async def _summarise_fast_fallback( + self, query: str, context: "SearchContext", + ) -> Tuple[str, bool]: + """Generate an answer using the FAST summary prompt with fallback evidence. - merged.sort(key=lambda f: f["weighted_score"], reverse=True) - pruned = self._prune_by_score(merged, top_k=top_k) + Reuses the existing ``ROI_RESULT_SUMMARY`` prompt, feeding it the + standard fallback text so that the LLM answers from its own knowledge. + """ + answer_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, + text_content=self._LLM_FALLBACK_EVIDENCE, + ) + answer_resp = await self.llm.achat( + messages=[{"role": "user", "content": answer_prompt}], + stream=True, + ) + self.llm_usages.append(answer_resp.usage) + if answer_resp.usage and isinstance(answer_resp.usage, dict): + context.add_llm_tokens( + answer_resp.usage.get("total_tokens", 0), usage=answer_resp.usage, + ) + answer, _, _ = self._parse_summary_response(answer_resp.content or "") + return answer, False # Never save fallback answers - return pruned if pruned else None + # ------------------------------------------------------------------ + # Agentic retrieval pipeline (DEEP mode) + # ------------------------------------------------------------------ - async def _fast_sample_evidence( + async def _analyze_data_requirements( + self, query: str, intent: str, + ) -> DataRequirements: + """Identify what data points the query needs before any retrieval.""" + try: + prompt = DEEP_DATA_REQUIREMENTS.format(query=query, intent=intent) + resp = await self.llm.achat( + [{"role": "user", "content": prompt}], stream=False, + ) + self.llm_usages.append(resp.usage) + raw = (resp.content or "").strip() + data = self._extract_json_object(raw) + if data: + return DataRequirements( + data_points=data.get("data_points", [query]), + likely_sources=data.get("likely_sources", []), + formula=data.get("formula"), + time_period=data.get("time_period"), + intent=intent, + ) + except Exception as exc: + await self._logger.warning( + f"[Phase 3] Data requirements analysis failed: {exc}" + ) + return DataRequirements( + data_points=[query], likely_sources=[], formula=None, + time_period=None, intent=intent, + ) + + def _select_target_files( self, - file_path: str, - match_objects: List[Dict[str, Any]], - ) -> str: - """Build focused evidence from grep hits: context windows for text - files, raw match snippets for binary formats. + merged_files: List[str], + scope: "_PathScope", + artifacts: Optional["CompileArtifacts"], + ) -> List[str]: + """Select top files for agentic retrieval, preferring tree-indexed ones.""" + scoped = [fp for fp in merged_files if scope.contains(fp)] + if not scoped: + scoped = list(merged_files) - Args: - file_path: Absolute path to the best file. - match_objects: Match event dicts from ``merge_results``. + tree_paths = ( + artifacts.tree_available_paths if artifacts else set() + ) + with_tree = [fp for fp in scoped if fp in tree_paths] + without_tree = [fp for fp in scoped if fp not in tree_paths] + ranked = with_tree + without_tree + return ranked[: self._AGENTIC_MAX_FILES] - Returns: - Formatted evidence string. - """ - fname = Path(file_path).name - ext = Path(file_path).suffix.lower() + async def _select_pages_for_data( + self, + query: str, + data_reqs: DataRequirements, + section_map: str, + evidence_so_far: str, + fetched_pages: set, + sections_meta: Optional[List[Dict[str, Any]]] = None, + total_pages: Optional[int] = None, + ) -> set: + """LLM-driven page selection given document outline and data needs.""" + reqs_str = "\n".join( + f"- {dp}" for dp in data_reqs.data_points + ) + if data_reqs.formula: + reqs_str += f"\nFormula: {data_reqs.formula}" - # Extract match line numbers - hit_lines: List[int] = [] - for m in match_objects: - ln = m.get("data", {}).get("line_number") - if isinstance(ln, int): - hit_lines.append(ln) + fetched_str = ( + ", ".join(str(p) for p in sorted(fetched_pages)) + if fetched_pages else "None" + ) - # Diagnostic logging when falling back to snippet mode - if not hit_lines and match_objects: + evidence_summary = ( + evidence_so_far[:2000] if evidence_so_far.strip() else "None yet" + ) + + prompt = DEEP_PAGE_SELECT.format( + query=query, + data_requirements=reqs_str, + section_map=section_map, + fetched_pages=fetched_str, + evidence_summary=evidence_summary, + ) + try: + resp = await self.llm.achat( + [{"role": "user", "content": prompt}], stream=False, + ) + self.llm_usages.append(resp.usage) + raw = (resp.content or "").strip() + match = re.search(r"\[[\d\s,]+\]", raw) + if match: + pages = json.loads(match.group()) + result = { + int(p) for p in pages + if isinstance(p, (int, float)) and int(p) > 0 + } + if result: + return result + except Exception as exc: await self._logger.warning( - f"[FAST] No line_number in {len(match_objects)} match(es) for {fname}, " - f"falling back to snippet mode" + f"[Phase 4] Page selection failed: {exc}" ) - # --- Text files: read context windows around hits --- - if ext in self._FAST_TEXT_EXTENSIONS and hit_lines: - # Expand context window for sparse hits - window = self._FAST_CONTEXT_WINDOW - if len(hit_lines) <= 2: - window = max(window, 100) # ±100 lines for 1-2 hits - evidence = self._read_context_windows( - file_path, hit_lines, - window=window, - max_chars=self._FAST_MAX_EVIDENCE_CHARS, + return self._fallback_page_selection( + data_reqs, sections_meta, fetched_pages, total_pages, + ) + + @staticmethod + def _fallback_page_selection( + data_reqs: DataRequirements, + sections_meta: Optional[List[Dict[str, Any]]], + fetched_pages: set, + total_pages: Optional[int], + ) -> set: + """Heuristic page selection when LLM fails or returns empty.""" + candidates: set = set() + + if sections_meta: + source_keywords = { + s.lower() for s in data_reqs.likely_sources + } + for sec in sections_meta: + title_lower = (sec.get("title") or "").lower() + pr = sec.get("page_range") + if not pr or not pr[0]: + continue + if any(kw in title_lower for kw in source_keywords): + start, end = int(pr[0]), int(pr[1]) + for p in range(start, min(start + 4, end + 1)): + candidates.add(p) + + if not candidates and total_pages and total_pages > 10: + mid = total_pages // 2 + last_quarter = total_pages * 3 // 4 + for p in range(mid, min(mid + 4, total_pages + 1)): + candidates.add(p) + for p in range(last_quarter, min(last_quarter + 4, total_pages + 1)): + candidates.add(p) + + return candidates - fetched_pages + + async def _check_data_requirements( + self, + query: str, + data_reqs: DataRequirements, + evidence: str, + ) -> Tuple[bool, List[str]]: + """Check if evidence satisfies all data requirements. + + Returns ``(is_complete, missing_data_points)``. + """ + try: + prompt = DEEP_CHECK_REQUIREMENTS.format( + query=query, + data_points="\n".join(f"- {dp}" for dp in data_reqs.data_points), + formula=data_reqs.formula or "N/A", + evidence=evidence[:self._AGENTIC_EVIDENCE_MAX_CHARS], ) - if evidence: - full_evidence = f"[{fname}]\n{evidence}" - if len(full_evidence) < 100: - await self._logger.info( - f"[FAST] Context window evidence too thin ({len(full_evidence)} chars) for {fname}, " - f"attempting file head extraction" + resp = await self.llm.achat( + [{"role": "user", "content": prompt}], stream=False, + ) + self.llm_usages.append(resp.usage) + raw = (resp.content or "").strip() + json_start = raw.find("{") + json_end = raw.rfind("}") + if json_start >= 0 and json_end > json_start: + data = json.loads(raw[json_start : json_end + 1]) + is_complete = bool(data.get("complete", True)) + missing = data.get("missing", []) + if isinstance(missing, list) and missing: + return False, [str(m) for m in missing[:5]] + return is_complete, [] + except Exception as exc: + await self._logger.warning( + f"[Phase 4] Data requirements check failed: {exc}" + ) + return True, [] + + async def _agentic_retrieve( + self, + query: str, + data_reqs: DataRequirements, + target_files: List[str], + context: "SearchContext", + ) -> RetrievalResult: + """Core agentic retrieval loop: select pages → extract → check → repeat.""" + indexer = self._get_tree_indexer() + evidence_parts: List[str] = [] + pages_extracted: Dict[str, set] = {} + total_pages = 0 + + outlines: Dict[str, str] = {} + outlines_meta: Dict[str, List[Dict[str, Any]]] = {} + file_total_pages: Dict[str, int] = {} + outline_target_files: List[str] = [] + + for fp in target_files: + tree = indexer.load_tree(fp) if indexer else None + if tree and tree.total_pages: + file_total_pages[fp] = tree.total_pages + if fp not in file_total_pages: + try: + from pypdf import PdfReader + file_total_pages[fp] = len(PdfReader(fp).pages) + except Exception: + pass + + tp = file_total_pages.get(fp) + if tp and tp <= self._SHORT_DOC_THRESHOLD: + fname = Path(fp).name + try: + all_pages = list(range(1, tp + 1)) + contents = DocumentExtractor.extract_pages(fp, all_pages) + for pc in contents: + if pc.content and pc.content.strip(): + evidence_parts.append( + f"[{fname} p.{pc.page_number}]\n{pc.content}" + ) + pages_extracted[fp] = set(all_pages) + total_pages += tp + except Exception as exc: + await self._logger.warning( + f"[Phase 4] Full extraction of short doc {fname} failed: {exc}" ) - head_evidence = await self._fast_read_file_head(file_path) - if head_evidence and len(head_evidence) > len(full_evidence): - return head_evidence - return full_evidence + outline_target_files.append(fp) + continue + outline_target_files.append(fp) - # --- Non-text files or no line numbers: use grep snippets --- - snippets: List[str] = [] - total = 0 - for m in match_objects: - line_text = m.get("data", {}).get("lines", {}).get("text", "").rstrip() - if not line_text: + for fp in outline_target_files: + tp = file_total_pages.get(fp) + + # Strategy 1: LLM-analyzed TOC pages (highest quality) + toc_outline, toc_meta = await self._build_outline_from_toc_pages(fp, tp) + if toc_outline.strip(): + outlines[fp] = toc_outline + outlines_meta[fp] = toc_meta continue - snippets.append(line_text) - total += len(line_text) - if total >= self._FAST_MAX_EVIDENCE_CHARS: - break - if snippets: - snippet_evidence = f"[{fname}]\n" + "\n".join(snippets) - # If snippet evidence is too thin, try file head for richer context - if len(snippet_evidence) < 100: - await self._logger.info( - f"[FAST] Evidence too thin ({len(snippet_evidence)} chars) for {fname}, " - f"attempting file head extraction" + # Strategy 2: Tree-index section map (fallback) + tree = indexer.load_tree(fp) if indexer else None + if tree and tree.root: + outline, sec_meta = self._build_section_map( + tree.root, max_depth=self._AGENTIC_SECTION_MAP_DEPTH, ) - head_evidence = await self._fast_read_file_head(file_path) - if head_evidence and len(head_evidence) > len(snippet_evidence): - return head_evidence - return snippet_evidence + if outline.strip(): + outlines[fp] = outline + outlines_meta[fp] = sec_meta + continue - # Last resort: try reading file head - return await self._fast_read_file_head(file_path) + # Strategy 3: Sampled content outline for docs with known page count + if tp: + sampled_outline, sampled_meta = self._build_sampled_outline( + fp, tp, + ) + outlines[fp] = sampled_outline + outlines_meta[fp] = sampled_meta + + current_reqs = data_reqs + + if not outline_target_files and evidence_parts: + combined = "\n\n".join(evidence_parts) + return RetrievalResult( + evidence=combined[:self._AGENTIC_EVIDENCE_MAX_CHARS], + pages_extracted={ + fp: sorted(ps) for fp, ps in pages_extracted.items() + }, + is_complete=True, + rounds_used=0, + ) - @staticmethod - def _read_context_windows( - file_path: str, - hit_lines: List[int], - window: int = 30, - max_chars: int = 15_000, - ) -> Optional[str]: - """Read context windows around *hit_lines* from a text file. + for round_idx in range(self._AGENTIC_MAX_ROUNDS): + round_fetched_any = False - Merges overlapping windows to avoid duplication. Stops when - *max_chars* is reached. - """ - # Merge overlapping intervals - intervals = sorted(set( - (max(1, ln - window), ln + window) for ln in hit_lines - )) - merged: List[tuple] = [intervals[0]] - for start, end in intervals[1:]: - if start <= merged[-1][1] + 1: - merged[-1] = (merged[-1][0], max(merged[-1][1], end)) - else: - merged.append((start, end)) + for fp in outline_target_files: + if total_pages >= self._AGENTIC_MAX_TOTAL_PAGES: + break - # Read file and extract windows - try: - with open(file_path, "r", encoding="utf-8", errors="replace") as f: - all_lines = f.readlines() - except Exception: - return None + fname = Path(fp).name + fetched = pages_extracted.get(fp, set()) + outline = outlines.get(fp, "") + sec_meta = outlines_meta.get(fp) + tp = file_total_pages.get(fp) - parts: List[str] = [] - total = 0 - for start, end in merged: - s = max(0, start - 1) # 0-indexed - e = min(len(all_lines), end) - chunk = "".join(all_lines[s:e]) - if total + len(chunk) > max_chars: - remaining = max_chars - total - if remaining > 200: - chunk = chunk[:remaining] + "\n[...truncated...]" - parts.append(chunk) - break - parts.append(chunk) - total += len(chunk) + if not outline and not tp: + continue - if not parts: - return None + new_pages = await self._select_pages_for_data( + query, current_reqs, outline or "(no outline available)", + "\n\n".join(evidence_parts)[:8000], + fetched, + sections_meta=sec_meta, + total_pages=tp, + ) + new_pages -= fetched + if not new_pages: + continue + + budget = self._AGENTIC_MAX_PAGES_PER_ROUND + capped = sorted(new_pages)[:budget] + try: + contents = DocumentExtractor.extract_pages(fp, capped) + for pc in contents: + if pc.content and pc.content.strip(): + evidence_parts.append( + f"[{fname} p.{pc.page_number}]\n{pc.content}" + ) + pages_extracted.setdefault(fp, set()).update(capped) + total_pages += len(capped) + round_fetched_any = True + except Exception as exc: + await self._logger.warning( + f"[Phase 4] Page extraction failed for {fname}: {exc}" + ) + + # Append table digests for newly fetched pages only + try: + from sirchmunk.utils.file_utils import get_fast_hash + fhash = get_fast_hash(fp) + if fhash: + tables = self._load_table_digest(self.work_path, fhash) + if tables: + new_page_set = set(capped) + page_tables = [ + t for t in tables + if t.get("page_number") in new_page_set + ] + if page_tables: + table_ev = self._format_table_evidence( + page_tables, + max_chars=self._TABLE_EVIDENCE_DEFAULT_CHARS, + query=query, + ) + if table_ev: + evidence_parts.append( + f"[{fname} tables]\n{table_ev}" + ) + except Exception: + pass - # Join windows with separator when there are gaps - return "\n[...]\n".join(parts) + context.increment_loop() - @classmethod - async def _fast_read_file_head( - cls, file_path: str, max_chars: int = 8_000, - ) -> str: - """Read the head of a file as last-resort evidence.""" - try: - p = Path(file_path) - if p.suffix.lower() in cls._FAST_TEXT_EXTENSIONS: - text = p.read_text(encoding="utf-8", errors="replace") - else: - from sirchmunk.utils.file_utils import fast_extract - result = await fast_extract(file_path) - text = result.content if result and result.content else "" - if text: - return f"[{p.name}]\n{text[:max_chars]}" - except Exception: - pass - return "" + if not round_fetched_any: + break - @staticmethod - def _parse_fast_json(text: str) -> Dict[str, Any]: - """Extract JSON from the FAST query analysis LLM response.""" - text = text.strip() - try: - return json.loads(text) - except (json.JSONDecodeError, TypeError): - pass - cleaned = re.sub(r"^```(?:json)?\s*", "", text, flags=re.MULTILINE) - cleaned = re.sub(r"```\s*$", "", cleaned, flags=re.MULTILINE).strip() - try: - return json.loads(cleaned) - except (json.JSONDecodeError, TypeError): - pass - match = re.search(r"\{.*\}", text, re.DOTALL) - if match: - try: - return json.loads(match.group()) - except (json.JSONDecodeError, TypeError): - pass - return {} + combined = "\n\n".join(evidence_parts) + is_complete, missing = await self._check_data_requirements( + query, current_reqs, combined, + ) + context.increment_loop() - # ------------------------------------------------------------------ - # Phase 1 probes (each designed to run concurrently) - # ------------------------------------------------------------------ + await self._logger.info( + f"[Phase 4] Round {round_idx + 1}: " + f"{total_pages} pages, complete={is_complete}, " + f"missing={len(missing)}" + ) - async def _probe_keywords( - self, query: str, - ) -> Tuple[Dict[str, float], List[str]]: - """Extract multi-level keywords from the query via LLM. + if is_complete or not missing: + return RetrievalResult( + evidence=combined[:self._AGENTIC_EVIDENCE_MAX_CHARS], + pages_extracted={ + fp: sorted(ps) for fp, ps in pages_extracted.items() + }, + is_complete=True, + rounds_used=round_idx + 1, + ) - Also extracts cross-lingual alternative keywords from the - ```` block and merges them into the result list. + current_reqs = DataRequirements( + data_points=missing, + likely_sources=data_reqs.likely_sources, + formula=data_reqs.formula, + time_period=data_reqs.time_period, + intent=data_reqs.intent, + ) - Returns: - Tuple of (keyword_idf_dict, keyword_list). - """ - await self._logger.info("[Probe:Keywords] Extracting keywords...") - dynamic_prompt = generate_keyword_extraction_prompt(num_levels=2) - keyword_prompt = dynamic_prompt.replace(KEYWORD_QUERY_PLACEHOLDER, query) - kw_response = await self.llm.achat( - messages=[{"role": "user", "content": keyword_prompt}], - stream=False, + combined = "\n\n".join(evidence_parts) + return RetrievalResult( + evidence=combined[:self._AGENTIC_EVIDENCE_MAX_CHARS], + pages_extracted={ + fp: sorted(ps) for fp, ps in pages_extracted.items() + }, + is_complete=False, + rounds_used=self._AGENTIC_MAX_ROUNDS, ) - self.llm_usages.append(kw_response.usage) - keyword_sets = self._extract_and_validate_multi_level_keywords( - kw_response.content, num_levels=2, + async def _synthesize_from_retrieval( + self, + query: str, + intent: str, + retrieval: RetrievalResult, + file_paths: List[str], + formula: Optional[str] = None, + ) -> Tuple[str, bool, Optional["KnowledgeCluster"]]: + """Synthesize final answer from agentic retrieval evidence.""" + if not retrieval.evidence.strip(): + return _NO_RESULTS_MESSAGE, False, None + + evidence = retrieval.evidence + if formula and intent == "computation": + evidence = f"[Required Formula: {formula}]\n\n{evidence}" + + synth_prompt = self._select_synthesis_prompt( + query, evidence, intent, + ) + resp = await self.llm.achat( + messages=[{"role": "user", "content": synth_prompt}], + stream=True, ) + self.llm_usages.append(resp.usage) - alt_keywords = self._extract_alt_keywords(kw_response.content) - if alt_keywords: - await self._logger.info(f"[Probe:Keywords] Cross-lingual alt: {list(alt_keywords.keys())}") + raw = resp.content or "" + answer, should_save, should_answer = self._parse_summary_response(raw) - for kw_set in keyword_sets: - if kw_set: - merged = {**kw_set, **alt_keywords} - kw_list = list(merged.keys()) - await self._logger.info(f"[Probe:Keywords] Extracted: {kw_list}") - return merged, kw_list + accepted, reason = self._evaluate_evidence_acceptance( + query, retrieval.evidence, should_answer, + retrieval_complete=retrieval.is_complete, + ) + await self._logger.info( + f"[Phase 4.5] Synthesis: accepted={accepted} ({reason})" + ) - if alt_keywords: - return alt_keywords, list(alt_keywords.keys()) + if not accepted: + return _NO_RESULTS_MESSAGE, False, None - return {}, [] + cluster = self._make_answer_cluster( + query, retrieval.evidence[:5000], "AGT", + file_paths=list(retrieval.pages_extracted.keys())[:3], + ) + cluster.content = retrieval.evidence + return answer, should_save, cluster - @staticmethod - def _has_directory_paths(paths: List[str]) -> bool: - """Return True if any element in *paths* is a directory.""" - return any(Path(p).is_dir() for p in paths) + # ------------------------------------------------------------------ + # LLM-powered document outline from TOC pages + # ------------------------------------------------------------------ - @staticmethod - def _resolve_file_hints( - paths: List[str], - file_hints: List[str], - max_depth: int = 8, - ) -> List[str]: - """Resolve file_hints (filenames) to absolute paths under *paths*. + _TOC_ANALYSIS_PAGES: List[int] = [1, 2, 3, 4, 5] - Lightweight name-only search: no metadata extraction. Used when the - user clearly asks for a specific document (e.g. "总结《foo.pdf》") - so we can skip full dir scan + LLM rank. + async def _build_outline_from_toc_pages( + self, + file_path: str, + total_pages: Optional[int] = None, + ) -> Tuple[str, List[Dict[str, Any]]]: + """Build a section map by extracting and LLM-analyzing TOC pages. - Returns: - List of absolute path strings that match any hint (deduplicated, - order preserved). Empty if no matches. - """ - if not file_hints: - return [] + Extracts the first few pages of a PDF (where the Table of Contents + typically resides), sends the text to the LLM for structural parsing, + and returns an outline string plus section metadata in the same format + as ``_build_section_map()`` for seamless integration. - hints = [h.strip() for h in file_hints if (h and isinstance(h, str))] - if not hints: - return [] + Results are cached per file hash to avoid repeated LLM calls. + """ + from sirchmunk.utils.file_utils import get_fast_hash - def _name_matches(name: str, hint: str) -> bool: - name_n = name.strip() - hint_n = hint.strip() - if not hint_n: - return False - if name_n == hint_n: - return True - if hint_n.lower() in name_n.lower(): - return True - if Path(name_n).stem == Path(hint_n).stem: - return True - return False + fhash = get_fast_hash(file_path) + if not fhash: + return "", [] - seen: set = set() - out: List[str] = [] + cache_dir = self.work_path / ".cache" / "compile" / "toc_outlines" + cache_path = cache_dir / f"{fhash}.json" - def walk_dir(d: Path, depth: int) -> None: - if depth > max_depth or len(out) >= 20: - return + sections_raw: Optional[list] = None + if cache_path.exists(): try: - for entry in sorted(d.iterdir(), key=lambda p: p.name): - if len(out) >= 20: - return - if entry.name.startswith("."): - continue - if entry.is_file(): - for hint in hints: - if _name_matches(entry.name, hint): - resolved = str(entry.resolve()) - if resolved not in seen: - seen.add(resolved) - out.append(resolved) - break - elif entry.is_dir(): - walk_dir(entry, depth + 1) - except PermissionError: + sections_raw = json.loads(cache_path.read_text()) + except Exception: pass - for p_str in paths: - p = Path(p_str).resolve() - if p.is_file(): - for hint in hints: - if _name_matches(p.name, hint): - resolved = str(p) - if resolved not in seen: - seen.add(resolved) - out.append(resolved) - break - elif p.is_dir(): - walk_dir(p, 0) + if sections_raw is None: + try: + contents = DocumentExtractor.extract_pages( + file_path, self._TOC_ANALYSIS_PAGES, + ) + toc_text = "\n\n".join( + f"--- Page {pc.page_number} ---\n{pc.content}" + for pc in contents if pc.content and pc.content.strip() + ) + if len(toc_text.strip()) < 200: + return "", [] - return out + tp = total_pages or len(contents) + prompt = DEEP_TOC_ANALYSIS.format( + toc_page_text=toc_text[:12000], + total_pages=tp, + ) + resp = await self.llm.achat( + [{"role": "user", "content": prompt}], stream=False, + ) + self.llm_usages.append(resp.usage) - async def _probe_dir_scan( - self, - paths: List[str], - enable: bool = True, - max_files: int = 500, - ): - """Scan directories for file metadata (filesystem only, no LLM). + raw = (resp.content or "").strip() + sections_raw = self._extract_json_array(raw) + if sections_raw is None: + sections_raw = [] - Automatically skips scanning when all *paths* are single files. + cache_dir.mkdir(parents=True, exist_ok=True) + cache_path.write_text(json.dumps(sections_raw, ensure_ascii=False)) + except Exception as exc: + await self._logger.warning( + f"[Phase 4] TOC outline extraction failed: {exc}" + ) + return "", [] - Args: - paths: Normalised list of path strings to scan. - enable: Whether directory scanning is enabled. - max_files: Cap on number of files to scan (lower = faster). + if not sections_raw: + return "", [] - Returns: - ScanResult or None if disabled / all paths are files. - """ - if not enable or not self._has_directory_paths(paths): - return None + return self._toc_sections_to_outline(sections_raw, total_pages) - from sirchmunk.scan.dir_scanner import DirectoryScanner + @staticmethod + def _toc_sections_to_outline( + sections_raw: list, + total_pages: Optional[int] = None, + ) -> Tuple[str, List[Dict[str, Any]]]: + """Convert raw TOC section list to outline string and metadata.""" + sections_meta: List[Dict[str, Any]] = [] + + for i, sec in enumerate(sections_raw): + if not isinstance(sec, dict): + continue + title = str(sec.get("title", "")).strip() + if not title: + continue - if self._dir_scanner is None or self._dir_scanner.max_files != max_files: - self._dir_scanner = DirectoryScanner(llm=self.llm, max_files=max_files) + ps = sec.get("page_start") + pe = sec.get("page_end") + level = int(sec.get("level", 1)) - 1 + + page_range = None + if ps is not None: + ps = int(ps) + pe = int(pe) if pe is not None else (total_pages or ps) + page_range = [ps, pe] + + idx = len(sections_meta) + sections_meta.append({ + "idx": idx, + "title": title, + "page_range": page_range, + "char_range": None, + "depth": level, + "node_id": f"toc_{idx}", + "summary": "", + }) + + # Post-process: fix page_range errors from LLM inference + for i, sec in enumerate(sections_meta): + pr = sec.get("page_range") + if not pr: + continue + needs_fix = pr[1] < pr[0] + if not needs_fix and pr[1] == pr[0] and i + 1 < len(sections_meta): + next_pr = sections_meta[i + 1].get("page_range") + needs_fix = next_pr and next_pr[0] == pr[0] + if needs_fix: + for j in range(i + 1, len(sections_meta)): + next_pr = sections_meta[j].get("page_range") + if next_pr and next_pr[0] > pr[0]: + pr[1] = next_pr[0] - 1 + break + else: + pr[1] = total_pages or pr[0] - await self._logger.info("[Probe:DirScan] Scanning directories...") - scan_result = await self._dir_scanner.scan(paths) - await self._logger.info( - f"[Probe:DirScan] Found {scan_result.total_files} files " - f"in {scan_result.total_dirs} dirs ({scan_result.scan_duration_ms:.0f}ms)" - ) - return scan_result + lines: List[str] = [] + for sec in sections_meta: + pr = sec.get("page_range") + indent = " " * sec["depth"] + page_str = f"(p{pr[0]}-{pr[1]})" if pr else "" + lines.append(f"[{sec['idx']}] {indent}{sec['title']} {page_str}") - async def _probe_knowledge_cache( - self, query: str, - ) -> List[str]: - """Search knowledge cache for related clusters, return known file paths. + return "\n".join(lines), sections_meta - Returns: - List of file paths from previously cached clusters. + @staticmethod + def _build_sampled_outline( + file_path: str, + total_pages: int, + interval: int = 10, + ) -> Tuple[str, List[Dict[str, Any]]]: + """Build an outline by sampling page content at regular intervals. + + Used as a fallback when TOC parsing and tree indices are unavailable. + Gives the LLM enough context to make informed page selections. """ - try: - clusters = await self.knowledge_storage.find(query, limit=3) - if not clusters: - return [] + sample_pages = list(range(1, total_pages + 1, interval)) + if total_pages not in sample_pages: + sample_pages.append(total_pages) - file_paths: List[str] = [] - for c in clusters: - for ev in getattr(c, "evidences", []): - fp = str(getattr(ev, "file_or_url", "")) - if fp and Path(fp).exists(): - file_paths.append(fp) + sections_meta: List[Dict[str, Any]] = [] + lines: List[str] = [] - if file_paths: - await self._logger.info( - f"[Probe:Knowledge] Found {len(file_paths)} files from cached clusters" - ) - return file_paths + try: + contents = DocumentExtractor.extract_pages(file_path, sample_pages) + page_snippets = { + pc.page_number: (pc.content or "").strip()[:200] + for pc in contents if pc.content + } except Exception: - return [] + page_snippets = {} - @staticmethod - async def _async_noop(default=None): - """No-op coroutine used as placeholder in gather().""" - return default + for i, pg in enumerate(sample_pages): + snippet = page_snippets.get(pg, "") + snippet_clean = " ".join(snippet.split())[:150] + next_pg = sample_pages[i + 1] if i + 1 < len(sample_pages) else total_pages + page_range = [pg, next_pg] - # ------------------------------------------------------------------ - # Phase 2 retrievers - # ------------------------------------------------------------------ + title = f'p{pg}: "{snippet_clean}..."' if snippet_clean else f"p{pg}" + sections_meta.append({ + "idx": i, "title": title, + "page_range": page_range, "char_range": None, + "depth": 0, "node_id": f"sample_{i}", "summary": "", + }) + lines.append(f"[{i}] {title} (p{page_range[0]}-{page_range[1]})") - async def _retrieve_by_keywords( - self, - keywords: List[str], - paths: List[str], - max_depth: Optional[int] = 5, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - ) -> List[str]: - """Run keyword search via rga and return discovered file paths. + return "\n".join(lines), sections_meta - Each keyword is searched concurrently (literal per-term strategy). - """ - from sirchmunk.agentic.tools import KeywordSearchTool + # ------------------------------------------------------------------ + # Deep Structured Reasoning pipeline (legacy, used by older code paths) + # ------------------------------------------------------------------ - tool = KeywordSearchTool( - retriever=self.grep_retriever, - paths=paths, - max_depth=max_depth if max_depth is not None else 5, - max_results=20, - include=include, - exclude=exclude, - ) - ctx = SearchContext() # lightweight context for this probe - result_text, meta = await tool.execute(context=ctx, keywords=keywords) + @staticmethod + def _build_section_map( + root: Any, + max_depth: int = 2, + ) -> Tuple[str, List[Dict[str, Any]]]: + """Build a lightweight section map from the top layers of a tree index. - # Extract discovered file paths from the tool's context logs - discovered: List[str] = [] - for log_entry in ctx.retrieval_logs: - discovered.extend(log_entry.metadata.get("files_discovered", [])) + Args: + root: A ``TreeNode`` root from a ``DocumentTree``. - await self._logger.info( - f"[Retrieve:Keywords] {len(discovered)} files from rga search" - ) - return discovered + Returns a human-readable map string (with numbered indices so the LLM + can reference specific sections) and a parallel list of section + metadata dicts for programmatic use. + """ + sections: List[Dict[str, Any]] = [] - async def _rank_dir_scan_candidates( + def _walk(node: Any, depth: int) -> None: + if depth > max_depth: + return + pr = node.page_range + idx = len(sections) + sections.append({ + "idx": idx, + "title": node.title, + "page_range": list(pr) if pr else None, + "char_range": list(node.char_range) if getattr(node, "char_range", None) else None, + "depth": depth, + "node_id": node.node_id, + "summary": (node.summary or "")[:120], + }) + for child in node.children: + _walk(child, depth + 1) + + children = root.children if root.children else [root] + while len(children) == 1 and children[0].children and not children[0].leaf: + children = children[0].children + + for child in children: + _walk(child, 0) + + map_lines: List[str] = [] + for sec in sections: + indent = " " * sec["depth"] + pr = sec.get("page_range") + page_str = f"(p{pr[0]}-{pr[1]})" if pr and pr[0] else "" + map_lines.append(f"[{sec['idx']}] {indent}{sec['title']} {page_str}") + + return "\n".join(map_lines), sections + + async def _select_evidence_sections( self, query: str, - scan_result, - *, - top_k: int = 20, - include_medium: bool = False, - ) -> List[str]: - """Run LLM ranking on dir_scan candidates and return relevant paths. + section_map: str, + sections_meta: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """LLM-driven selection of relevant sections from a section map. - Args: - include_medium: When True, include both high and medium relevance. + Returns the metadata dicts for the selected sections. """ - if self._dir_scanner is None: - return [] - - ranked = await self._dir_scanner.rank(query, scan_result, top_k=top_k) - accept = {"high", "medium"} if include_medium else {"high"} - paths = [ - c.path for c in ranked.ranked_candidates - if c.relevance in accept - ] - await self._logger.info( - f"[Retrieve:DirScan] {len(paths)} relevant files " - f"(accept={accept})" + prompt = DEEP_SECTION_SELECT.format( + query=query, + section_map=section_map, ) - return paths + resp = await self.llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + self.llm_usages.append(resp.usage) - async def _scan_and_rank_paths( + raw = (resp.content or "").strip() + # Parse JSON array of indices + try: + match = re.search(r"\[[\s\d,]*\]", raw) + if match: + indices = json.loads(match.group(0)) + return [ + sections_meta[i] + for i in indices + if isinstance(i, int) and 0 <= i < len(sections_meta) + ] + except (json.JSONDecodeError, IndexError): + pass + + # Fallback: return sections that have page_range data + return [s for s in sections_meta if s.get("page_range")][:3] + + async def _extract_targeted_pages( self, + file_path: str, + selected_sections: List[Dict[str, Any]], query: str, - paths: List[str], - *, - max_files: int = 300, - top_k: int = 20, - include_medium: bool = True, - ) -> List[str]: - """Scan directories and return LLM-ranked relevant file paths. + ) -> str: + """Extract content for LLM-selected sections. - Combines :meth:`_probe_dir_scan` (filesystem walk) and - :meth:`_rank_dir_scan_candidates` (LLM ranking) in one call. - Automatically skips scanning when all *paths* are single files. + Two extraction strategies (tried in order): + 1. **Page-based** — ``DocumentExtractor.extract_pages`` for PDFs. + 2. **Char-range** — direct text slice from compile cache or + fast_extract for any file type. - Returns: - Ranked file paths (high + optionally medium relevance), - or empty list when scanning is not applicable. + Table digests are appended when available. Caps output at + ``_DEEP_STRUCTURED_MAX_CHARS``. """ - scan_result = await self._probe_dir_scan( - paths, enable=True, max_files=max_files, - ) - if scan_result is None: - return [] + parts: List[str] = [] - return await self._rank_dir_scan_candidates( - query, scan_result, - top_k=top_k, include_medium=include_medium, - ) + # Strategy 1: page-based extraction (PDF) + pages_needed: Set[int] = set() + for sec in selected_sections: + pr = sec.get("page_range") + if pr and len(pr) == 2 and pr[0]: + pages_needed.update(range( + max(1, pr[0] - self._NAV_PAGE_MARGIN), + pr[1] + self._NAV_PAGE_MARGIN + 1, + )) + + if pages_needed: + sorted_pages = sorted(pages_needed)[: self._DEEP_MAX_EXTRACT_PAGES] + try: + page_contents = DocumentExtractor.extract_pages( + file_path, sorted_pages, + ) + for pc in page_contents: + if pc.content and pc.content.strip(): + parts.append(f"[Page {pc.page_number}]\n{pc.content}") + except Exception as e: + await self._logger.warning( + f"[DeepStructured] Page extraction failed for " + f"{Path(file_path).name}: {e}" + ) - # ------------------------------------------------------------------ - # Phase 3: Merge + cluster build - # ------------------------------------------------------------------ + # Strategy 2: char_range fallback (non-PDF or when pages failed) + if not parts: + full_text = self._load_compile_content(self.work_path, file_path) + if not full_text: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + if full_text: + for sec in selected_sections: + cr = sec.get("char_range") + if cr and len(cr) == 2 and cr[0] is not None: + start, end = cr + if 0 <= start < end <= len(full_text): + segment = full_text[start:end] + if segment.strip(): + parts.append( + f"[{sec.get('title', 'Section')}]\n{segment}" + ) - @staticmethod - def _merge_file_paths( - keyword_files: List[str], - dir_scan_files: List[str], - knowledge_hits: List[str], - ) -> List[str]: - """Merge file paths from all retrieval paths, dedup, preserve priority. + # Append relevant table digests when available + if pages_needed: + try: + from sirchmunk.utils.file_utils import get_fast_hash + fhash = get_fast_hash(file_path) + if fhash: + tables = self._load_table_digest(self.work_path, fhash) + if tables: + page_tables = [ + t for t in tables + if t.get("page_number") in pages_needed + ] + if page_tables: + table_ev = self._format_table_evidence( + page_tables, + max_chars=self._TABLE_EVIDENCE_DEFAULT_CHARS, + query=query, + ) + if table_ev: + parts.append(f"[Table Evidence]\n{table_ev}") + except Exception: + pass - Priority: keyword_search > knowledge_cache > dir_scan. + evidence = "\n\n".join(parts) + return evidence[: self._DEEP_STRUCTURED_MAX_CHARS] + + async def _deep_structured_reasoning( + self, + query: str, + tree_files: List[str], + artifacts: Any, + context: "SearchContext", + intent: str = "", + ) -> Tuple[str, Optional["KnowledgeCluster"], str]: + """Orchestrate the Deep Structured Reasoning pipeline. + + Phases: + 1. Section map — build from tree index top layers (no LLM) + 2. Section select — LLM picks relevant sections (1 LLM) + 3. Targeted extraction — pull pages + tables for sections (no LLM) + 4. Synthesis — intent-aware prompt on targeted evidence (1 LLM) + 5. Recovery — if refused, expand sections and re-synthesize + + Returns ``(raw_llm_output, cluster, combined_evidence)`` where + *combined_evidence* is the raw document text fed to the LLM so + callers can use it for evidence-acceptance checks instead of + the LLM's answer text. """ - seen: set = set() - merged: List[str] = [] + indexer = self._get_tree_indexer() + if indexer is None: + return "", None, "" - for fp in keyword_files + knowledge_hits + dir_scan_files: - if fp and fp not in seen: - seen.add(fp) - merged.append(fp) + all_evidence_parts: List[str] = [] - return merged + for fp in tree_files[: self._DEEP_STRUCTURED_MAX_FILES]: + fname = Path(fp).name + tree = indexer.load_tree(fp) + if tree is None or tree.root is None: + continue - async def _build_cluster( - self, - query: str, - file_paths: List[str], - query_keywords: Dict[str, float], - top_k_files: int = 5, - top_k_snippets: int = 5, - ) -> Optional[KnowledgeCluster]: - """Build a KnowledgeCluster via knowledge_base.build(). + section_map, sections_meta = self._build_section_map( + tree.root, max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH, + ) + if not sections_meta: + continue - Constructs the Request wrapper and delegates to the knowledge - base for parallel Monte Carlo evidence sampling. - """ - try: - request = Request( - messages=[ - Message( - role="user", - content=[ContentItem(type="text", text=query)], - ), - ], + await self._logger.info( + f"[DeepSR] Section map for {fname}: " + f"{len(sections_meta)} sections" ) - retrieved_infos = [{"path": fp} for fp in file_paths] - cluster = await self.knowledge_base.build( - request=request, - retrieved_infos=retrieved_infos, - keywords=query_keywords, - top_k_files=top_k_files, - top_k_snippets=top_k_snippets, - verbose=self.verbose, + selected = await self._select_evidence_sections( + query, section_map, sections_meta, ) - self.llm_usages.extend(self.knowledge_base.llm_usages) - self.knowledge_base.llm_usages.clear() + context.increment_loop() + if not selected: + continue - if cluster: - await self._logger.success( - f"[Phase 3] KnowledgeCluster built: {cluster.name} " - f"({len(cluster.evidences)} evidence units)" - ) - return cluster - except Exception as exc: - await self._logger.warning(f"[Phase 3] knowledge_base.build() failed: {exc}") - return None + await self._logger.info( + f"[DeepSR] Selected {len(selected)} sections: " + f"{[s['title'][:30] for s in selected]}" + ) - # ------------------------------------------------------------------ - # Phase 4: Answer generation - # ------------------------------------------------------------------ + raw_evidence = await self._extract_targeted_pages( + fp, selected, query, + ) + if not raw_evidence: + continue - async def _summarise_cluster( - self, query: str, cluster: KnowledgeCluster, - ) -> Tuple[str, bool, bool]: - """Generate a final answer summary from a KnowledgeCluster. + await self._logger.info( + f"[DeepSR] Extracted {len(raw_evidence)} chars from {fname}" + ) - Returns: - ``(summary_text, should_save, should_answer)`` where: - - should_save: quality verdict for persistence - - should_answer: evidence sufficiency verdict for answering - """ - sep = "\n" - cluster_text_content = ( - f"{cluster.name}\n\n" - f"{sep.join(cluster.description)}\n\n" - f"{cluster.content if isinstance(cluster.content, str) else sep.join(cluster.content)}" - ) + all_evidence_parts.append(f"[Source: {fname}]\n{raw_evidence}") - result_sum_prompt = SEARCH_RESULT_SUMMARY.format( - user_input=query, - text_content=cluster_text_content, + if not all_evidence_parts: + return "", None, "" + + combined_evidence = "\n\n---\n\n".join(all_evidence_parts) + + # Build document context from artifacts when available + doc_context: Optional[str] = None + if artifacts and artifacts.catalog_map: + ctx_parts = [ + self._build_answer_context(fp, artifacts) + for fp in tree_files[: self._DEEP_STRUCTURED_MAX_FILES] + ] + ctx_parts = [c for c in ctx_parts if c] + if ctx_parts: + doc_context = "\n".join(ctx_parts) + + # Synthesize answer using intent-aware prompt + synth_prompt = self._select_synthesis_prompt( + query, combined_evidence, intent, + document_context=doc_context, ) - await self._logger.info("[Phase 4] Generating search result summary...") - response = await self.llm.achat( - messages=[{"role": "user", "content": result_sum_prompt}], + resp = await self.llm.achat( + messages=[{"role": "user", "content": synth_prompt}], stream=True, ) - self.llm_usages.append(response.usage) - - summary, should_save, should_answer = self._parse_summary_response(response.content) - return summary, should_save, should_answer + self.llm_usages.append(resp.usage) + context.increment_loop() - async def _summarise_cluster_fallback(self, query: str) -> Tuple[str, bool]: - """Generate an answer using the DEEP summary prompt with fallback evidence. + raw_response = resp.content or "" + _, _, should_answer = self._parse_summary_response(raw_response) - Reuses the existing ``SEARCH_RESULT_SUMMARY`` prompt, feeding it the - standard fallback text so that the LLM answers from its own knowledge - without adding an extra LLM call to the pipeline. - """ - result_sum_prompt = SEARCH_RESULT_SUMMARY.format( - user_input=query, - text_content=self._LLM_FALLBACK_EVIDENCE, + await self._logger.info( + f"[DeepSR] Synthesis complete: should_answer={should_answer}, " + f"len={len(raw_response)}" ) - await self._logger.info("[Phase 4] Generating fallback summary from LLM knowledge...") - response = await self.llm.achat( - messages=[{"role": "user", "content": result_sum_prompt}], - stream=True, + + # Recovery: if the answer is a refusal, try expanding sections + if (not should_answer or self._is_refusal_answer(raw_response[:500])): + for recovery_round in range(1, self._DEEP_MAX_RECOVERY_ROUNDS + 1): + await self._logger.info( + f"[DeepSR] Recovery round {recovery_round}" + ) + expanded_parts: List[str] = list(all_evidence_parts) + found_new = False + for fp in tree_files[: self._DEEP_STRUCTURED_MAX_FILES]: + tree = indexer.load_tree(fp) + if tree is None or tree.root is None: + continue + section_map, sections_meta = self._build_section_map( + tree.root, + max_depth=self._DEEP_SECTION_MAP_MAX_DEPTH + recovery_round, + ) + if not sections_meta: + continue + recovery_selected = await self._select_evidence_sections( + query, section_map, sections_meta, + ) + context.increment_loop() + if not recovery_selected: + continue + recovery_ev = await self._extract_targeted_pages( + fp, recovery_selected, query, + ) + if recovery_ev and recovery_ev not in combined_evidence: + expanded_parts.append( + f"[Recovery source: {Path(fp).name}]\n{recovery_ev}" + ) + found_new = True + if not found_new: + break + combined_evidence = "\n\n---\n\n".join(expanded_parts) + synth_prompt = self._select_synthesis_prompt( + query, + combined_evidence[:self._DEEP_STRUCTURED_MAX_CHARS], + intent, + document_context=doc_context, + ) + resp = await self.llm.achat( + messages=[{"role": "user", "content": synth_prompt}], + stream=True, + ) + self.llm_usages.append(resp.usage) + context.increment_loop() + raw_response = resp.content or "" + _, _, should_answer = self._parse_summary_response(raw_response) + if should_answer and not self._is_refusal_answer( + raw_response[:500] + ): + break + + cluster = self._make_answer_cluster( + query, combined_evidence[:5000], "DSR", + file_paths=tree_files[: self._DEEP_STRUCTURED_MAX_FILES], ) - self.llm_usages.append(response.usage) - summary, _, _ = self._parse_summary_response(response.content) - return summary, False # Never save fallback answers - async def _summarise_fast_fallback( - self, query: str, context: "SearchContext", - ) -> Tuple[str, bool]: - """Generate an answer using the FAST summary prompt with fallback evidence. + return raw_response, cluster, combined_evidence - Reuses the existing ``ROI_RESULT_SUMMARY`` prompt, feeding it the - standard fallback text so that the LLM answers from its own knowledge. + async def _deep_self_correct( + self, + query: str, + merged_files: List[str], + query_keywords: Dict[str, float], + context: "SearchContext", + ) -> Optional[str]: + """Gather alternative evidence when DEEP Phase 4 answer is rejected. + + Four strategies tried in order, stopping at first success: + A) Expanded tree-guided sampling on the primary file. + B) rga keyword window extraction on primary files using + Phase-1 keywords (reuses the rga infrastructure). + C) Semantically similar cluster from knowledge storage. + D) Tree-guided sampling on secondary merged files. + + Returns alternative evidence text, or ``None`` when every + strategy fails. """ - answer_prompt = ROI_RESULT_SUMMARY.format( - user_input=query, - text_content=self._LLM_FALLBACK_EVIDENCE, - ) - answer_resp = await self.llm.achat( - messages=[{"role": "user", "content": answer_prompt}], - stream=True, - ) - self.llm_usages.append(answer_resp.usage) - if answer_resp.usage and isinstance(answer_resp.usage, dict): - context.add_llm_tokens( - answer_resp.usage.get("total_tokens", 0), usage=answer_resp.usage, + primary_files = merged_files[:2] + secondary_files = merged_files[2:5] + + # Strategy A: expanded tree sampling on primary file + for fp in primary_files: + expanded_ev = await self._tree_guided_sample( + fp, query, + max_chars=self._FAST_MAX_EVIDENCE_CHARS * 2, ) - answer, _, _ = self._parse_summary_response(answer_resp.content or "") - return answer, False # Never save fallback answers + if isinstance(expanded_ev, str) and len(expanded_ev.strip()) > 100: + await self._logger.info( + "[DEEP:SelfCorrect] Strategy A succeeded: " + f"expanded tree sample from {Path(fp).name}" + ) + return expanded_ev + + # Strategy B: tree-navigated evidence with expanded parameters + for fp in primary_files: + try: + nav_ev = await self._navigate_tree_for_evidence( + fp, query, + max_results=self._SELF_CORRECT_EXPANDED_NAV_RESULTS, + ) + if nav_ev and len(nav_ev.strip()) > 100: + await self._logger.info( + "[DEEP:SelfCorrect] Strategy B succeeded: " + f"expanded tree navigation on {Path(fp).name}" + ) + return nav_ev + except Exception: + pass + + # Strategy C: semantically similar cluster from knowledge storage + if self.embedding_client and self.knowledge_storage: + try: + qe = self.embedding_client.encode(query) + if qe is not None: + vec = qe.tolist() if hasattr(qe, "tolist") else list(qe) + hits = await self.knowledge_storage.search_similar_clusters( + query_embedding=vec, top_k=2, similarity_threshold=0.50, + ) + if hits: + parts: List[str] = [] + for h in hits[:2]: + c = await self.knowledge_storage.get(h["id"]) + if c and c.content: + parts.append(str(c.content)[:3000]) + for ev in (c.evidences or [])[:3]: + for s in (ev.snippets or [])[:2]: + parts.append(s[:500]) + if parts: + await self._logger.info( + "[DEEP:SelfCorrect] Strategy C succeeded: " + "knowledge storage cluster" + ) + return "\n\n---\n\n".join(parts) + except Exception: + pass + + # Strategy D: tree sampling on secondary files + for fp in secondary_files: + tree_ev = await self._tree_guided_sample( + fp, query, + max_chars=self._FAST_MAX_EVIDENCE_CHARS, + ) + if isinstance(tree_ev, str) and len(tree_ev.strip()) > 100: + context.mark_file_read(fp) + await self._logger.info( + "[DEEP:SelfCorrect] Strategy D succeeded: " + f"secondary file {Path(fp).name}" + ) + return tree_ev + + await self._logger.info("[DEEP:SelfCorrect] All strategies exhausted") + return None async def _react_refinement( self, diff --git a/src/sirchmunk/storage/knowledge_storage.py b/src/sirchmunk/storage/knowledge_storage.py index 0a99168..c74e05a 100644 --- a/src/sirchmunk/storage/knowledge_storage.py +++ b/src/sirchmunk/storage/knowledge_storage.py @@ -107,6 +107,12 @@ def _load_from_parquet(self): variable-length ``FLOAT[]`` from Parquet's list encoding, breaking ``list_cosine_similarity`` which requires matching fixed-size types. + Handles schema evolution gracefully with adaptive column mapping: + - Forward compatible: old parquet (more cols) → new table (fewer cols), + extra columns in parquet are silently ignored. + - Backward compatible: new parquet (fewer cols) → old table (more cols), + missing columns are filled with defaults. + Also records the file's modification time so that ``_check_and_reload()`` can detect external changes later. """ @@ -117,11 +123,63 @@ def _load_from_parquet(self): self.db.drop_table(self.table_name, if_exists=True) # Create table with explicit schema (preserves FLOAT[384]) self._create_table() - # Insert data from parquet — DuckDB casts to the declared types - self.db.execute( - f"INSERT INTO {self.table_name} " - f"SELECT * FROM read_parquet('{self.parquet_file}')" - ) + + # Adaptive column mapping: detect parquet & table columns + parquet_cols = self._get_parquet_columns(self.parquet_file) + table_cols = self._get_table_columns() + + if not parquet_cols or not table_cols: + logger.warning( + "Could not detect columns for adaptive mapping, " + "skipping parquet load" + ) + else: + parquet_col_set = set(parquet_cols) + table_col_set = set(table_cols) + # Compute common columns (preserve table column order) + common_cols = [c for c in table_cols if c in parquet_col_set] + + if not common_cols: + logger.warning( + "No common columns between parquet and table, " + "skipping parquet load" + ) + else: + # Log column mismatches as warnings + ignored_cols = parquet_col_set - table_col_set + missing_cols = table_col_set - parquet_col_set + if ignored_cols: + logger.warning( + "Parquet has extra columns (ignored): %s", + ignored_cols, + ) + if missing_cols: + logger.warning( + "Table has extra columns (filled with defaults): %s", + missing_cols, + ) + + # Build INSERT with explicit column lists + # For common cols: select directly from parquet + # For missing cols (in table but not in parquet): use defaults + insert_cols = list(table_cols) # all table columns + select_parts = [] + for col_name in table_cols: + if col_name in parquet_col_set: + select_parts.append(col_name) + elif col_name == "merge_count": + select_parts.append("0 AS merge_count") + else: + select_parts.append(f"NULL AS {col_name}") + + cols_str = ", ".join(insert_cols) + select_clause = ", ".join(select_parts) + self.db.execute( + f"INSERT INTO {self.table_name} ({cols_str}) " + f"SELECT {select_clause} " + f"FROM read_parquet('{self.parquet_file}')" + ) + count = self.db.get_table_count(self.table_name) # Record mtime for stale-detection self._parquet_loaded_mtime = pq.stat().st_mtime @@ -132,12 +190,63 @@ def _load_from_parquet(self): self._parquet_loaded_mtime = 0.0 logger.info("Created new knowledge clusters table") except Exception as e: - logger.error(f"Failed to load from parquet: {e}") - # Try to recreate table - self.db.drop_table(self.table_name, if_exists=True) - self._create_table() + logger.warning(f"Failed to load from parquet (non-blocking): {e}") + # Try to recreate table so retrieval can still work + try: + self.db.drop_table(self.table_name, if_exists=True) + self._create_table() + except Exception as recreate_err: + logger.warning(f"Failed to recreate table after load failure: {recreate_err}") self._parquet_loaded_mtime = 0.0 + def _get_schema_columns(self) -> List[str]: + """Return the ordered list of column names in the canonical schema.""" + return [ + "id", "name", "description", "content", "scripts", "resources", + "evidences", "patterns", "constraints", "confidence", + "abstraction_level", "landmark_potential", "hotness", "lifecycle", + "create_time", "last_modified", "version", "related_clusters", + "search_results", "queries", "merge_count", + "embedding_vector", "embedding_model", "embedding_timestamp", + "embedding_text_hash", + ] + + def _get_parquet_columns(self, parquet_path: str) -> List[str]: + """Get column names from a parquet file's schema. + + Uses DuckDB's ``parquet_schema()`` function. The returned metadata + rows use a ``name`` field (not ``column_name``). + + Returns: + Ordered list of column names, or empty list on failure. + """ + try: + rows = self.db.fetch_all( + f"SELECT name FROM parquet_schema('{parquet_path}') " + f"WHERE name != 'duckdb_schema'" + ) + return [row[0] for row in rows] + except Exception as e: + logger.warning(f"Failed to read parquet schema: {e}") + return [] + + def _get_table_columns(self) -> List[str]: + """Get column names from the current DuckDB table. + + Returns: + Ordered list of column names, or empty list on failure. + """ + try: + rows = self.db.fetch_all( + "SELECT column_name FROM information_schema.columns " + f"WHERE table_name = '{self.table_name}' " + "ORDER BY ordinal_position" + ) + return [row[0] for row in rows] + except Exception as e: + logger.warning(f"Failed to read table columns: {e}") + return [] + def _check_and_reload(self): """Check if the parquet file was modified externally and reload if so. @@ -190,6 +299,7 @@ def _create_table(self): "related_clusters": "VARCHAR", # JSON array "search_results": "VARCHAR", # JSON array "queries": "VARCHAR", # JSON array of historical queries + "merge_count": "INTEGER", # compile merge counter "embedding_vector": "FLOAT[384]", # 384-dim embedding vector "embedding_model": "VARCHAR", # Model identifier "embedding_timestamp": "TIMESTAMP", # Embedding computation time @@ -338,21 +448,22 @@ def _cluster_to_row(self, cluster: KnowledgeCluster) -> Dict[str, Any]: "related_clusters": json.dumps([rc.to_dict() for rc in cluster.related_clusters]), "search_results": json.dumps(cluster.search_results) if cluster.search_results else None, "queries": json.dumps(cluster.queries) if cluster.queries else None, + "merge_count": cluster.merge_count or 0, } def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: """ Convert database row to KnowledgeCluster. - Expected row structure (24 columns): + Expected row structure (25 columns): id, name, description, content, scripts, resources, evidences, patterns, constraints, confidence, abstraction_level, landmark_potential, hotness, lifecycle, create_time, last_modified, version, related_clusters, search_results, queries, - embedding_vector, embedding_model, embedding_timestamp, embedding_text_hash + merge_count, embedding_vector, embedding_model, embedding_timestamp, embedding_text_hash """ - if len(row) != 24: + if len(row) != 25: raise ValueError( - f"Expected 24 columns in knowledge_clusters row, got {len(row)}. " + f"Expected 25 columns in knowledge_clusters row, got {len(row)}. " f"Please ensure the table schema is up to date." ) @@ -361,6 +472,7 @@ def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: id, name, description, content, scripts, resources, evidences, patterns, constraints, confidence, abstraction_level, landmark_potential, hotness, lifecycle, create_time, last_modified, version, related_clusters, search_results, queries, + merge_count, _embedding_vector, _embedding_model, _embedding_timestamp, _embedding_text_hash ) = row @@ -400,7 +512,9 @@ def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: is_found=ev_dict["is_found"], snippets=ev_dict["snippets"], extracted_at=extracted_at_parsed or datetime.now(), - conflict_group=ev_dict.get("conflict_group") + conflict_group=ev_dict.get("conflict_group"), + tree_path=ev_dict.get("tree_path"), + page_range=ev_dict.get("page_range"), )) # Parse constraints @@ -463,6 +577,7 @@ def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: related_clusters=related_clusters_parsed, search_results=search_results_parsed, queries=queries_parsed, + merge_count=merge_count or 0, ) # ------------------------------------------------------------------ # diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py new file mode 100644 index 0000000..f115687 --- /dev/null +++ b/src/sirchmunk/utils/document_extractor.py @@ -0,0 +1,713 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Unified document extraction facade over kreuzberg. + +Centralizes all kreuzberg interaction into a single module, providing a clean, +configurable interface for document text extraction with support for tables, +metadata, language detection, OCR, and page-range filtering. + +All other modules should import from here rather than from kreuzberg directly. +""" + +from __future__ import annotations + +import asyncio +import dataclasses +import multiprocessing as mp +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, ClassVar, List, Optional, Sequence, Union + +from loguru import logger + + +# --------------------------------------------------------------------------- +# Subprocess extraction helpers (module-level for picklability) +# --------------------------------------------------------------------------- + +_EXTRACT_TIMEOUT_S = 600 + + +def _extraction_worker( + file_path: str, + profile_dict: dict[str, Any], + pipe_w: mp.connection.Connection, +) -> None: + """Child process entry point: run kreuzberg, send result via pipe, exit. + + Sends a plain dict so no native kreuzberg/Rust objects cross the + process boundary. On failure sends ``{"_error": ""}``. + """ + try: + import asyncio as _aio + + async def _run() -> dict[str, Any]: + from sirchmunk.utils.document_extractor import ( + DocumentExtractor, + ExtractionProfile, + ) + profile = ExtractionProfile(**profile_dict) + output = await DocumentExtractor.extract(file_path, profile) + return { + "content": output.content, + "mime_type": output.mime_type, + "metadata": output.metadata, + "tables": output.tables, + "detected_languages": output.detected_languages, + "page_count": output.page_count, + } + + pipe_w.send(_aio.run(_run())) + except BaseException as exc: + try: + pipe_w.send({"_error": str(exc)}) + except Exception: + pass + finally: + pipe_w.close() + + +def _run_extraction_in_child( + file_path: str, + profile_dict: dict[str, Any], +) -> dict[str, Any]: + """Spawn an isolated child process, wait for its result. + + Unlike ``ProcessPoolExecutor``, a crash in one child never + poisons future extractions — each call spawns a fresh process. + """ + pipe_r, pipe_w = mp.Pipe(duplex=False) + proc = mp.Process( + target=_extraction_worker, + args=(file_path, profile_dict, pipe_w), + daemon=True, + ) + proc.start() + pipe_w.close() + + try: + if not pipe_r.poll(timeout=_EXTRACT_TIMEOUT_S): + proc.kill() + proc.join(timeout=10) + raise RuntimeError( + f"Extraction timed out after {_EXTRACT_TIMEOUT_S}s" + ) + result = pipe_r.recv() + except EOFError: + proc.join(timeout=10) + raise RuntimeError( + f"Worker crashed (exit code {proc.exitcode})" + ) + finally: + pipe_r.close() + + proc.join(timeout=30) + if proc.is_alive(): + proc.kill() + proc.join() + + if isinstance(result, dict) and "_error" in result: + raise RuntimeError(result["_error"]) + return result + + +# --------------------------------------------------------------------------- +# Configuration profile +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class ExtractionProfile: + """Immutable extraction configuration profile. + + Controls which kreuzberg features are enabled during document extraction. + Default values align with the legacy ``fast_extract()`` behavior + (plain text only, no extras). + """ + + output_format: str = "plain" + """Output format: ``plain`` | ``markdown`` | ``html`` | ``djot``.""" + + extract_tables: bool = False + """Whether to extract and return tables.""" + + extract_metadata: bool = False + """Whether to return document metadata.""" + + detect_language: bool = False + """Whether to detect document language.""" + + ocr_enabled: bool = False + """Whether to enable OCR fallback.""" + + ocr_backend: str = "tesseract" + """OCR engine: ``tesseract`` | ``easyocr`` | ``paddleocr``.""" + + ocr_language: str = "eng" + """OCR language code (e.g. ``eng``, ``chi_sim``).""" + + page_start: Optional[int] = None + """Page range start (0-indexed). ``None`` means first page.""" + + page_end: Optional[int] = None + """Page range end (inclusive). ``None`` means last page.""" + + pdf_extract_images: bool = False + """Extract images embedded in PDF pages.""" + + pdf_extract_metadata: bool = False + """Extract PDF-level metadata (author, title, etc.).""" + + force_ocr: bool = False + """Force OCR for all pages, bypassing native text extraction. + + Maps directly to kreuzberg's ``ExtractionConfig.force_ocr``. + Note: kreuzberg does not offer a "fallback" OCR mode — + when set, OCR is always applied regardless of text layer presence. + """ + + force_ocr_pages: Optional[tuple[int, ...]] = None + """Force OCR on specific pages only (0-indexed). + + Maps to kreuzberg's ``ExtractionConfig.force_ocr_pages``. + Mutually exclusive with :attr:`force_ocr` — when both are set, + ``force_ocr`` takes precedence. + """ + + pdf_password: Optional[str] = None + """Password for encrypted PDFs.""" + + max_concurrent: Optional[int] = None + """Max concurrency for batch extraction.""" + + +# --------------------------------------------------------------------------- +# Extraction output +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class ExtractionOutput: + """Structured extraction result. + + Always contains ``content``. Other fields are populated based on the + :class:`ExtractionProfile` settings used during extraction. + """ + + content: str + """Extracted text content.""" + + mime_type: str = "" + """MIME type of the source document.""" + + metadata: dict[str, Any] = field(default_factory=dict) + """Document metadata (empty when ``extract_metadata`` is disabled).""" + + tables: list[dict[str, Any]] = field(default_factory=list) + """Extracted tables (empty when ``extract_tables`` is disabled).""" + + detected_languages: dict[str, float] = field(default_factory=dict) + """Language → confidence mapping (empty when ``detect_language`` is disabled).""" + + page_count: Optional[int] = None + """Number of pages in the source document (if available).""" + + +# --------------------------------------------------------------------------- +# Page-level extraction output +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class PageContent: + """Single page extraction result. + + Returned by :meth:`DocumentExtractor.extract_pages` to represent the + text content of one PDF page. + """ + + page_number: int + """1-indexed page number.""" + + content: str + """Extracted text content (may be empty string).""" + + +# --------------------------------------------------------------------------- +# Document extractor facade +# --------------------------------------------------------------------------- + +class DocumentExtractor: + """Unified document extraction facade over kreuzberg. + + Provides a clean, configurable interface for document text extraction, + centralizing all kreuzberg interaction within a single module. + + Usage:: + + # Basic extraction (identical to legacy fast_extract) + result = await DocumentExtractor.extract(path) + + # Enhanced extraction with tables and metadata + result = await DocumentExtractor.extract(path, DocumentExtractor.ENHANCED) + + # Custom profile + profile = ExtractionProfile(output_format="markdown", extract_tables=True) + result = await DocumentExtractor.extract(path, profile) + """ + + # Pre-defined profiles ------------------------------------------------- + + BASIC: ClassVar[ExtractionProfile] = ExtractionProfile() + """Plain-text extraction only — equivalent to legacy ``fast_extract()``.""" + + ENHANCED: ClassVar[ExtractionProfile] = ExtractionProfile( + output_format="markdown", + extract_tables=True, + extract_metadata=True, + pdf_extract_metadata=True, + force_ocr=False, + ) + """Rich extraction with tables, metadata, and layout-based table detection. + + ``force_ocr`` is disabled because: + - Most documents (e.g. 10-K, 10-Q PDFs) already contain a native text layer. + - kreuzberg automatically falls back to OCR for scanned / image-only pages. + - Forcing OCR triggers Tesseract ObjectCache leak warnings in concurrent use + and significantly slows down compilation with no quality benefit. + """ + + # Public API ----------------------------------------------------------- + + @staticmethod + async def extract( + file_path: Union[str, Path], + profile: Optional[ExtractionProfile] = None, + ) -> ExtractionOutput: + """Extract content from a single file. + + Args: + file_path: Path to the document. + profile: Extraction profile. Defaults to :attr:`BASIC`. + + Returns: + :class:`ExtractionOutput` with at least ``content`` populated. + + Raises: + FileNotFoundError: If *file_path* does not exist. + Exception: Propagates kreuzberg extraction errors after logging. + """ + from kreuzberg import extract_file + + profile = profile or DocumentExtractor.BASIC + config = DocumentExtractor._build_config(profile) + + try: + result = await extract_file(file_path=file_path, config=config) + output = DocumentExtractor._convert_result(result, profile) + # Fallback: kreuzberg 4.9.1 returns page_count=0 when force_ocr=True; + # use pypdf to get the real page count when missing. + if output.page_count is None: + fallback = DocumentExtractor._fallback_page_count(file_path) + if fallback is not None: + output = ExtractionOutput( + content=output.content, + mime_type=output.mime_type, + metadata=output.metadata, + tables=output.tables, + detected_languages=output.detected_languages, + page_count=fallback, + ) + return output + except Exception as exc: + logger.error( + "Document extraction failed for {}: {}", + file_path, + exc, + ) + raise + + @staticmethod + async def extract_isolated( + file_path: Union[str, Path], + profile: Optional[ExtractionProfile] = None, + ) -> ExtractionOutput: + """Extract content in a fully isolated child process. + + Each call spawns a fresh ``multiprocessing.Process``. When the + child exits (normally or via crash), the OS reclaims **all** of + its native memory — Rust arenas, layout-model buffers, image + caches — guaranteeing zero accumulation in the parent. + + Unlike ``ProcessPoolExecutor``, a crash in one extraction never + poisons future calls. + + Falls back to in-process extraction on subprocess failure. + """ + profile = profile or DocumentExtractor.BASIC + profile_dict = { + f.name: getattr(profile, f.name) + for f in dataclasses.fields(profile) + } + + loop = asyncio.get_event_loop() + try: + raw = await loop.run_in_executor( + None, + _run_extraction_in_child, + str(file_path), + profile_dict, + ) + return ExtractionOutput( + content=raw["content"], + mime_type=raw.get("mime_type", ""), + metadata=raw.get("metadata", {}), + tables=raw.get("tables", []), + detected_languages=raw.get("detected_languages", {}), + page_count=raw.get("page_count"), + ) + except Exception as exc: + logger.warning( + "Subprocess extraction failed for {}, falling back to in-process: {}", + file_path, exc, + ) + return await DocumentExtractor.extract(file_path, profile) + + @staticmethod + async def extract_bytes( + data: bytes, + mime_type: str, + profile: Optional[ExtractionProfile] = None, + ) -> ExtractionOutput: + """Extract content from raw bytes. + + Args: + data: File content as bytes. + mime_type: MIME type of the data (required for format detection). + profile: Extraction profile. Defaults to :attr:`BASIC`. + + Returns: + :class:`ExtractionOutput`. + """ + from kreuzberg import extract_bytes as _extract_bytes + + profile = profile or DocumentExtractor.BASIC + config = DocumentExtractor._build_config(profile) + + try: + result = await _extract_bytes(data=data, mime_type=mime_type, config=config) + return DocumentExtractor._convert_result(result, profile) + except Exception: + logger.error("Byte extraction failed for mime_type={}", mime_type) + raise + + @staticmethod + async def batch_extract( + file_paths: Sequence[Union[str, Path]], + profile: Optional[ExtractionProfile] = None, + ) -> List[ExtractionOutput]: + """Extract content from multiple files in parallel. + + Args: + file_paths: Sequence of document paths. + profile: Extraction profile. Defaults to :attr:`BASIC`. + + Returns: + List of :class:`ExtractionOutput`, one per input path. + """ + from kreuzberg import batch_extract_files + + profile = profile or DocumentExtractor.BASIC + config = DocumentExtractor._build_config(profile) + + try: + results = await batch_extract_files(paths=list(file_paths), config=config) + outputs = [ + DocumentExtractor._convert_result(r, profile) for r in results + ] + # Apply page_count fallback for each output + fixed: List[ExtractionOutput] = [] + for output, fp in zip(outputs, file_paths): + if output.page_count is None: + fallback = DocumentExtractor._fallback_page_count(fp) + if fallback is not None: + output = ExtractionOutput( + content=output.content, + mime_type=output.mime_type, + metadata=output.metadata, + tables=output.tables, + detected_languages=output.detected_languages, + page_count=fallback, + ) + fixed.append(output) + return fixed + except Exception: + logger.error("Batch extraction failed for {} files", len(file_paths)) + raise + + # Page-level extraction ------------------------------------------------- + + @staticmethod + def extract_pages( + file_path: Union[str, Path], + pages: list[int], + ) -> list[PageContent]: + """Extract text content from specific PDF pages. + + Uses pypdf to read individual pages by 1-indexed page number. + Invalid page numbers (< 1 or > total pages) are silently skipped. + + Args: + file_path: Path to a PDF file. + pages: List of 1-indexed page numbers to extract. + + Returns: + List of :class:`PageContent` for each valid requested page, + in the order given by *pages*. + + Raises: + FileNotFoundError: If *file_path* does not exist. + Exception: On PDF parsing failure (logged before re-raise). + """ + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"PDF file not found: {path}") + + try: + from pypdf import PdfReader + + reader = PdfReader(str(path)) + total = len(reader.pages) + valid_pages = [p for p in pages if 1 <= p <= total] + return [ + PageContent( + page_number=p, + content=reader.pages[p - 1].extract_text() or "", + ) + for p in valid_pages + ] + except FileNotFoundError: + raise + except Exception as exc: + logger.error( + "Page-level extraction failed for {}: {}", + file_path, + exc, + ) + raise + + @staticmethod + def extract_page_range( + file_path: Union[str, Path], + start_page: int, + end_page: int, + ) -> list[PageContent]: + """Extract text content from a contiguous range of PDF pages. + + Convenience wrapper around :meth:`extract_pages`. + + Args: + file_path: Path to a PDF file. + start_page: First page (1-indexed, inclusive). + end_page: Last page (1-indexed, inclusive). + + Returns: + List of :class:`PageContent` for the requested range. + """ + pages = list(range(start_page, end_page + 1)) + return DocumentExtractor.extract_pages(file_path, pages) + + # Internal helpers ----------------------------------------------------- + + @staticmethod + def _fallback_page_count( + file_path: Union[str, Path], + ) -> Optional[int]: + """Get page count via pypdf when kreuzberg fails to report it. + + kreuzberg >= 4.9.1 returns ``get_page_count() == 0`` when + ``force_ocr=True`` is set. This fallback uses pypdf (already a + transitive dependency) for a lightweight page-count-only read. + + Returns: + Page count, or None for non-PDF files or on error. + """ + if Path(file_path).suffix.lower() != ".pdf": + return None + try: + from pypdf import PdfReader + reader = PdfReader(str(file_path)) + count = len(reader.pages) + return count if count > 0 else None + except Exception: + return None + + @staticmethod + def _build_config(profile: ExtractionProfile): + """Build a kreuzberg ``ExtractionConfig`` from an :class:`ExtractionProfile`. + + Maps profile fields to the kreuzberg configuration objects that are + actually available in the installed version. + """ + from kreuzberg import ( + ExtractionConfig, + OcrConfig, + OutputFormat, + PageConfig, + PdfConfig, + ) + + # --- Output format --- + format_map = { + "plain": OutputFormat.PLAIN, + "markdown": OutputFormat.MARKDOWN, + "html": OutputFormat.HTML, + "djot": OutputFormat.DJOT, + } + output_format = format_map.get(profile.output_format, OutputFormat.PLAIN) + + # --- OCR config --- + ocr_config: Optional[OcrConfig] = None + if profile.ocr_enabled: + ocr_config = OcrConfig( + backend=profile.ocr_backend, + language=profile.ocr_language, + ) + + # --- Page config --- + page_config: Optional[PageConfig] = None + if profile.page_start is not None or profile.page_end is not None: + # kreuzberg PageConfig.extract_pages expects a list of page indices + pages: Optional[list[int]] = None + if profile.page_start is not None: + end = profile.page_end if profile.page_end is not None else profile.page_start + pages = list(range(profile.page_start, end + 1)) + page_config = PageConfig(extract_pages=pages) + + # --- PDF config --- + pdf_config: Optional[PdfConfig] = None + if any([ + profile.pdf_extract_images, + profile.pdf_extract_metadata, + profile.pdf_password, + ]): + passwords = [profile.pdf_password] if profile.pdf_password else None + pdf_config = PdfConfig( + extract_images=profile.pdf_extract_images, + extract_metadata=profile.pdf_extract_metadata, + passwords=passwords, + ) + + # --- Language detection --- + lang_config = None + if profile.detect_language: + from kreuzberg import LanguageDetectionConfig + lang_config = LanguageDetectionConfig(enabled=True) + + # --- Layout detection for table extraction --- + layout_config = None + if profile.extract_tables: + try: + from kreuzberg import LayoutDetectionConfig + layout_config = LayoutDetectionConfig( + confidence_threshold=0.3, + apply_heuristics=True, + table_model="slanet_auto", + ) + except ImportError: + pass + + # --- Assemble ExtractionConfig --- + kwargs: dict[str, Any] = { + "output_format": output_format, + } + if ocr_config is not None: + kwargs["ocr"] = ocr_config + if profile.force_ocr: + kwargs["force_ocr"] = True + elif profile.force_ocr_pages: + kwargs["force_ocr_pages"] = list(profile.force_ocr_pages) + if page_config is not None: + kwargs["pages"] = page_config + if pdf_config is not None: + kwargs["pdf_options"] = pdf_config + if lang_config is not None: + kwargs["language_detection"] = lang_config + if profile.max_concurrent is not None: + kwargs["max_concurrent_extractions"] = profile.max_concurrent + if layout_config is not None: + kwargs["layout"] = layout_config + + return ExtractionConfig(**kwargs) + + @staticmethod + def _convert_result( + result: "ExtractionResult", + profile: ExtractionProfile, + ) -> ExtractionOutput: + """Convert a kreuzberg ``ExtractionResult`` to :class:`ExtractionOutput`. + + Only populates optional fields when the corresponding profile flag is + enabled, keeping the output lean for basic extraction. + """ + content: str = result.content or "" + mime_type: str = getattr(result, "mime_type", "") or "" + + # Metadata + metadata: dict[str, Any] = {} + if profile.extract_metadata: + raw_meta = getattr(result, "metadata", None) + if raw_meta is not None: + if isinstance(raw_meta, dict): + metadata = dict(raw_meta) + else: + # kreuzberg may return a non-dict metadata object + try: + metadata = dict(raw_meta) + except (TypeError, ValueError): + metadata = {"raw": str(raw_meta)} + + # Tables + tables: list[dict[str, Any]] = [] + if profile.extract_tables: + raw_tables = getattr(result, "tables", None) or [] + for t in raw_tables: + if isinstance(t, dict): + tables.append(t) + else: + # kreuzberg ExtractedTable has: cells, markdown, page_number + tables.append({ + "markdown": getattr(t, "markdown", ""), + "cells": getattr(t, "cells", []), + "page_number": getattr(t, "page_number", None), + }) + + # Language detection + detected_languages: dict[str, float] = {} + if profile.detect_language: + raw_langs = getattr(result, "detected_languages", None) + if raw_langs: + for entry in raw_langs: + if isinstance(entry, dict): + lang = entry.get("language", "") + conf = entry.get("confidence", 0.0) + else: + # kreuzberg DetectedLanguage object + lang = getattr(entry, "language", "") + conf = getattr(entry, "confidence", 0.0) + if lang: + detected_languages[lang] = float(conf) + + # Page count — prefer get_page_count() over get_chunk_count() + page_count: Optional[int] = None + get_page_count = getattr(result, "get_page_count", None) + if get_page_count and callable(get_page_count): + cnt = get_page_count() + if cnt is not None and cnt > 0: + page_count = cnt + + return ExtractionOutput( + content=content, + mime_type=mime_type, + metadata=metadata, + tables=tables, + detected_languages=detected_languages, + page_count=page_count, + ) diff --git a/src/sirchmunk/utils/file_utils.py b/src/sirchmunk/utils/file_utils.py index edbbc2d..df308fd 100644 --- a/src/sirchmunk/utils/file_utils.py +++ b/src/sirchmunk/utils/file_utils.py @@ -4,17 +4,29 @@ from pathlib import Path from typing import Union -from kreuzberg import ExtractionResult, extract_file from loguru import logger +from sirchmunk.utils.document_extractor import ( + DocumentExtractor, + ExtractionOutput, +) -async def fast_extract(file_path: Union[str, Path]) -> ExtractionResult: - """ - Automatically detects and extracts text content from various file formats like docx, pptx, pdf, xlsx. - """ - result: ExtractionResult = await extract_file(file_path=file_path) - return result +async def fast_extract(file_path: Union[str, Path]) -> ExtractionOutput: + """Extract text content from a file using kreuzberg. + + This is a backward-compatible wrapper around + :meth:`DocumentExtractor.extract` with the ``BASIC`` profile + (plain text, no extras). All callers that only need ``.content`` + continue to work unchanged. + + Args: + file_path: Path to the file to extract. + + Returns: + :class:`ExtractionOutput` with ``.content`` populated. + """ + return await DocumentExtractor.extract(file_path) def get_fast_hash(file_path: Union[str, Path], sample_size: int = 8192): diff --git a/src/sirchmunk/version.py b/src/sirchmunk/version.py index 87b826d..73c89eb 100644 --- a/src/sirchmunk/version.py +++ b/src/sirchmunk/version.py @@ -1 +1 @@ -__version__ = "0.0.7+main" +__version__ = "0.0.8+main"