# Source code for arborist.search.fts5

"""SQLite FTS5 keyword search. Returns UNGROUNDED-mode hits (no proof claim).

Snippet generation runs in Python because chunks_fts is contentless
(`content=''`) — SQLite's snippet()/highlight() functions return empty
strings under contentless mode. We JOIN the FTS5 hit rowid back to
`chunks.chunk_id` to fetch and decompress the original chunk content,
then locate query tokens locally.
"""

from __future__ import annotations

import re

from arborist.compress import unpack_chunk
from arborist.search.base import AuditMode, Hit, SearchBackend


_FTS5_TOKEN_RE = re.compile(r"[A-Za-z][A-Za-z0-9]*")
_FTS5_STOPWORDS = frozenset(
    """
    the a an is are was were be been being
    of to in on at for with by from as about into through during
    and or but not no nor so yet too very also just
    what who where when why how which this that these those such
    i you he she it we they me him her us them
    do does did have has had can could should would will may might
    tell show describe explain summarize say give list find make
    please
    all there know everything anything something

    one some another without soon currently
    """.split()
)


# OR-mode fallback hits SQLite FTS5 with one clause per token. With 15+
# tokens including common English connectors it matches millions of
# docs, then BM25 has to rank them all to pick the top-K — 13s/shard
# observed on a 3.47M-doc corpus. Long queries that fail AND-mode
# should not pay for an OR-mode that's effectively a full-corpus scan.
# Cap OR-mode at the top-N LONGEST tokens (proxy for rarity / topical
# specificity — "neurotechnology" matters; "soon" doesn't).
#
# Consumed by the OR branch of ``_escape_fts5`` and by the OR fallback in
# ``FTS5Backend.search``; in both places the pool is only trimmed when it
# holds MORE than this many tokens.
_OR_FALLBACK_MAX_TOKENS = 5

# Per-token document-frequency cap at OR-fallback time. A token whose
# corpus DF exceeds this threshold (e.g. "located" matching 286k chunks
# on a 1.5M-chunk wiki shard) contributes ~zero IDF to BM25 yet
# dominates the candidate set the engine has to score. Drop it from the
# OR pool before MATCH. Cost: one ``COUNT(MATCH "tok")`` per candidate
# (~15ms warm). Backstops the progressive-AND fallback in case AND
# returns zero on every chain — that path lands in OR-mode where this
# filter prevents a single high-DF token from blowing the wall.
#
# The bound is inclusive (a token with DF exactly equal to it is kept).
# ``_filter_or_pool_by_df`` reads this global at call time whenever its
# ``threshold`` argument is None, so monkeypatching it in tests works.
_OR_FALLBACK_MAX_TOKEN_DF = 50_000


def _progressive_and_token_chains(tokens: list[str]) -> list[list[str]]:
    """Yield successively-narrower AND token sets for progressive fallback.

    First chain is the full token list. Each subsequent chain drops one
    additional token, picked SHORTEST-FIRST with input order as the
    tie-break. Stops when only one token remains.

    Why shortest-first: short tokens are usually high-DF verbs/connectors
    that collapse the AND set to zero ("located" matches 286k chunks on
    a 1.5M-chunk wiki shard but contributes ~zero topical signal). The
    rare topical token ("Gundremmingen") is what keeps AND tight; we
    drop the cheap-signal tokens first so AND stays narrow and BM25
    isn't forced to rank a full-corpus union via OR-fallback.

    Concrete example, query tokens ["Gundremmingen", "located", "Bavaria"]:

        chain 0: Gundremmingen AND located AND Bavaria   (zero hits)
        chain 1: Gundremmingen AND Bavaria               ("located" dropped)
        chain 2: Gundremmingen                           ("Bavaria" dropped)

    The first chain that returns at least one row wins; OR-fallback only
    fires if every progressive-AND chain returns zero.
    """
    if len(tokens) <= 1:
        return [list(tokens)]
    # Decorate-sort-undecorate: sort by (length-asc, original-position-asc),
    # then drop one at a time from the front of the sort.
    indexed = sorted(enumerate(tokens), key=lambda iv: (len(iv[1]), iv[0]))
    drop_order = [i for i, _ in indexed]
    chains: list[list[str]] = []
    dropped: set[int] = set()
    chains.append(list(tokens))
    for drop_idx in drop_order[:-1]:  # always keep at least one token
        dropped.add(drop_idx)
        chains.append([t for i, t in enumerate(tokens) if i not in dropped])
    return chains


def _query_tokens(query: str) -> list[str]:
    """Extract the content tokens of a free-text query, in query order.

    Tokens are the matches of ``_FTS5_TOKEN_RE``. Single-character tokens
    and tokens whose lowercase form is in ``_FTS5_STOPWORDS`` are dropped.
    Surviving tokens keep their original casing; duplicates are kept.
    """
    kept: list[str] = []
    for candidate in _FTS5_TOKEN_RE.findall(query):
        if len(candidate) <= 1:
            continue
        if candidate.lower() in _FTS5_STOPWORDS:
            continue
        kept.append(candidate)
    return kept


def _quote(t: str) -> str:
    return '"' + t.replace('"', '""') + '"'


def _filter_or_pool_by_df(
    conn,
    tokens: list[str],
    *,
    threshold: int | None = None,
) -> list[str]:
    """Drop OR-pool tokens that are too common (or absent) in the corpus.

    Each token is probed with
    ``SELECT COUNT(*) FROM chunks_fts WHERE chunks_fts MATCH '"tok"'`` —
    the same FTS5 path a real search takes, so the probe cost mirrors a
    real lookup (~15ms warm, ~50ms cold per token on a 10GB shard).
    A token is kept when its count is at least 1 and at most
    ``threshold``; tokens with zero matches are dropped as well, since
    they cannot contribute any row to the OR union.

    ``threshold=None`` resolves to the module global
    ``_OR_FALLBACK_MAX_TOKEN_DF`` at call time, so a test can monkeypatch
    the global and have this helper see the new value (a default argument
    would freeze the constant at import).

    If nothing survives, a copy of the input is returned unchanged:
    OR-fallback is the last-resort retrieval path, and "some hits, slow"
    beats "no hits". A probe that raises (e.g. a malformed MATCH) is
    treated the same as a dropped token.

    Args:
        conn: Connection-like object exposing ``execute(sql, params)``
            whose result supports ``fetchone()``.
        tokens: Candidate OR tokens, in priority order.
        threshold: Inclusive upper bound on document frequency.

    Returns:
        The surviving tokens in input order, or a copy of ``tokens`` when
        every token was filtered out.
    """
    max_df = _OR_FALLBACK_MAX_TOKEN_DF if threshold is None else threshold

    survivors: list[str] = []
    for tok in tokens:
        try:
            hit = conn.execute(
                "SELECT COUNT(*) FROM chunks_fts WHERE chunks_fts MATCH ?",
                (_quote(tok),),
            ).fetchone()
        except Exception:
            # Deliberate best-effort: an unprobeable token is simply skipped.
            continue
        doc_freq = hit[0] if hit else 0
        if doc_freq <= 0 or doc_freq > max_df:
            continue
        survivors.append(tok)

    if survivors:
        return survivors
    return list(tokens)


def _escape_fts5(
    query: str,
    *,
    mode: str = "and",
    extra_or_tokens: set[str] | None = None,
) -> str:
    """Build an FTS5 MATCH expression from a free-text query.

    The query is tokenized into ASCII alphanumeric runs (so ``?`` and
    other punctuation can never break the quoted phrases) and question
    stopwords ('what', 'tell', ...) are removed. Then:

    - ``mode='and'`` (default): every content token must appear in the
      doc. Strict relevance — keeps unrelated docs out of the context
      window.
    - any other ``mode`` (conventionally ``'or'``): the top-N LONGEST
      tokens (length as a proxy for topical specificity) are OR-joined.
      Used as a fallback when AND returns zero hits. The pool is capped
      at ``_OR_FALLBACK_MAX_TOKENS`` so OR-mode does not degrade into a
      full-corpus scan on long noisy queries (a 19-token OR clause
      matches millions of docs and forces BM25 to rank them all —
      13s/shard observed pre-cap).

    ``extra_or_tokens`` carries synonym-expanded tokens (e.g.
    "telepathy" / "neurotechnology" expanded from "thoughts"). They are
    lowercased, deduplicated case-insensitively against the pool, filtered
    against the stopword list, and appended to the OR pool *before* the
    top-N cap, so a long synonym can outrank a short query token. They
    are ignored in AND mode and for all-stopword queries.

    Synonyms are merged in sorted order. The argument is a set, whose
    iteration order for strings varies between interpreter runs; because
    the length sort is stable, unsorted insertion would let the cap pick
    different equal-length tokens from run to run.

    All-stopword queries fall back to OR over the raw tokens longer than
    one character (uncapped, regardless of ``mode``) so they still find
    something instead of handing FTS5 an empty MATCH.

    Returns:
        The MATCH expression: double-quoted tokens joined by ``" AND "``
        or ``" OR "``.
    """
    tokens = _query_tokens(query)

    if not tokens:
        # Nothing but stopwords/punctuation: OR over whatever raw tokens
        # exist, or a placeholder so the expression is never empty.
        raw = _FTS5_TOKEN_RE.findall(query)
        fallback = [t for t in raw if len(t) > 1] or ['""']
        return " OR ".join(_quote(t) for t in fallback)

    if mode == "and":
        return " AND ".join(_quote(t) for t in tokens)

    # OR fallback: query tokens first, then unseen synonyms.
    pool = list(tokens)
    if extra_or_tokens:
        seen_lower = {t.lower() for t in pool}
        for synonym in sorted(extra_or_tokens):  # sorted => reproducible ties
            lowered = synonym.lower()
            if lowered and lowered not in seen_lower and lowered not in _FTS5_STOPWORDS:
                pool.append(lowered)
                seen_lower.add(lowered)

    # Keep only the longest tokens (approximation of "rarest") so the
    # per-clause cost stays bounded.
    if len(pool) > _OR_FALLBACK_MAX_TOKENS:
        pool = sorted(pool, key=len, reverse=True)[:_OR_FALLBACK_MAX_TOKENS]

    return " OR ".join(_quote(t) for t in pool)


# Snippet rendering. Locate the earliest occurrence of any query token
# (case-insensitive) in the chunk text and return a window of words around
# it: up to this many words before the matched word, then the matched word
# plus up to (this many - 1) words after it. Every query-token occurrence
# inside the window is wrapped in [brackets]. If no token matches (rare:
# query was all-stopwords or the tokens only appear in titles), fall back
# to the chunk's leading slice of 2x this many words.
_SNIPPET_WINDOW_WORDS = 16


def _build_snippet(text: str, query: str) -> str:
    """Return a short excerpt of ``text`` with query-token matches bracketed.

    The excerpt is a window of whitespace-separated words centred on the
    earliest occurrence of any query token: up to
    ``_SNIPPET_WINDOW_WORDS`` words before the matched word and the
    matched word plus up to ``_SNIPPET_WINDOW_WORDS - 1`` words after it.
    Words are re-joined with single spaces. Every occurrence of a query
    token inside the window is wrapped in ``[...]``; a leading/trailing
    ``…`` marks text cut off on that side.

    Matching is case-insensitive and substring-based (a token can match
    inside a longer word). All tokens are matched in one pass with a
    single alternation, longest token first, so a repeated or overlapping
    query token can never produce nested brackets such as ``[[Paris]]``
    or ``[Bav[aria]]``.

    Args:
        text: Decompressed chunk content.
        query: The user's free-text query.

    Returns:
        ``""`` for empty text; the first ``2 * _SNIPPET_WINDOW_WORDS``
        words when the query has no content tokens or none of them occur
        in ``text``; otherwise the bracketed window.
    """
    if not text:
        return ""

    words = text.split()

    # Case-insensitive dedupe that keeps first-seen order, then longest
    # first so the alternation prefers "newton" over "new" at one position.
    unique = list(dict.fromkeys(t.lower() for t in _query_tokens(query)))
    if not unique:
        # Best-effort: leading slice.
        return " ".join(words[: _SNIPPET_WINDOW_WORDS * 2])
    unique.sort(key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(t) for t in unique), re.IGNORECASE)

    # Leftmost match of any token == earliest occurrence in the text.
    first = pattern.search(text)
    if first is None:
        return " ".join(words[: _SNIPPET_WINDOW_WORDS * 2])

    # Exact char-position -> word-index mapping. Tokens are alphanumeric,
    # so the match starts on a non-whitespace character; the prefix up to
    # and including that character therefore ends inside the target word,
    # and its word count minus one is the target's index. Unlike summing
    # ``len(word) + 1``, this is immune to leading whitespace and to runs
    # of spaces/newlines that ``split()`` collapses.
    target_word = max(0, len(text[: first.start() + 1].split()) - 1)

    start = max(0, target_word - _SNIPPET_WINDOW_WORDS)
    end = min(len(words), target_word + _SNIPPET_WINDOW_WORDS)

    rendered = pattern.sub(
        lambda m: f"[{m.group(0)}]",
        " ".join(words[start:end]),
    )

    if start > 0:
        rendered = "…" + rendered
    if end < len(words):
        rendered = rendered + "…"
    return rendered


class FTS5Backend(SearchBackend):
    """Keyword search backend over the ``chunks_fts`` FTS5 index.

    Hits are tagged ``AuditMode.UNGROUNDED``. The backend reads
    ``self.conn`` (not defined in this class — presumably supplied by
    ``SearchBackend``; confirm) and indexes result rows by column name,
    so the connection must return mapping-style rows.
    """

    name = "fts5"
    audit_mode = AuditMode.UNGROUNDED

    def _match_rows(self, fts_query: str, limit: int) -> list:
        """Run one BM25-ranked MATCH and return its rows.

        Joins each FTS5 hit back to ``chunks`` (for the stored content)
        and ``documents`` (for URI and title). Any exception raised by the
        query is treated as "no rows" — search is best-effort and a
        malformed MATCH must not take the caller down.
        """
        try:
            return self.conn.execute(
                """
                SELECT c.document_root, c.idx, c.content AS raw_content,
                       bm25(chunks_fts) AS rank,
                       d.document_uri, d.title
                FROM chunks_fts AS f
                JOIN chunks AS c ON c.chunk_id = f.rowid
                JOIN documents AS d ON d.document_root = c.document_root
                WHERE chunks_fts MATCH ?
                ORDER BY rank ASC
                LIMIT ?
                """,
                (fts_query, limit),
            ).fetchall()
        except Exception:
            return []

    def search(
        self,
        query: str,
        limit: int = 20,
        extra_or_tokens: set[str] | None = None,
    ) -> list[Hit]:
        """Run FTS5 BM25 over chunk content.

        Retrieval strategy, first non-empty result wins:

        1. Progressive AND: the full content-token set, then narrower
           sets with the shortest tokens dropped one at a time.
        2. OR fallback over the content tokens plus ``extra_or_tokens``,
           after removing over-common tokens and capping the pool.

        ``extra_or_tokens`` (synonym-expanded set) is passed through to
        the OR-mode fallback only. AND mode stays on original query
        tokens (adding synonyms there would relax the AND constraint and
        pull in noise). The intended caller is the retrieval pipeline
        that has already computed ``synonym_expand(qtokens)`` — passing
        it here saves the OR-mode pool from missing topical synonym
        terms. Synonyms are merged in sorted order so that the top-N cap
        breaks length ties the same way on every run (set iteration
        order for strings is not stable across interpreter runs).

        Args:
            query: Free-text user query.
            limit: Maximum number of hits to return.
            extra_or_tokens: Optional synonym tokens for the OR fallback.

        Returns:
            Hits ordered best-first; ``[]`` for a blank query or when
            nothing matches.
        """
        if not query.strip():
            return []

        # Progressive-AND fallback before OR. AND-mode is signal-dense
        # but brittle: a single high-DF token in the query (e.g.
        # "located") can collapse the intersection to zero even though
        # the rest of the tokens uniquely identify the topic. Falling
        # straight to OR-mode on a 1.5M-chunk shard pulls 286k matches
        # when "located" is one of the OR clauses and forces BM25 to rank
        # them all (~27s cold I/O). Progressive-AND drops shortest-first
        # and retries before surrendering to OR.
        rows: list = []
        for chain in _progressive_and_token_chains(_query_tokens(query)):
            if not chain:
                continue
            rows = self._match_rows(" AND ".join(_quote(t) for t in chain), limit)
            if rows:
                break

        # OR-mode fallback. Folds in the synonym OR-pool, then drops any
        # token whose corpus DF exceeds ``_OR_FALLBACK_MAX_TOKEN_DF`` —
        # backstops progressive-AND when every chain still returned zero.
        # Without this filter a single high-DF query token (e.g.
        # "located": 286k matches on a 1.5M-chunk shard) would force BM25
        # to score ~all of those matches just to pick the top-K.
        if not rows:
            or_pool = _query_tokens(query)
            if not or_pool:
                # All-stopword query: keep the raw tokens longer than one
                # character (same as the _escape_fts5 fallback).
                raw = _FTS5_TOKEN_RE.findall(query)
                or_pool = [t for t in raw if len(t) > 1] or ['""']
            else:
                # NOTE(review): the original indentation of this method
                # was lost; the DF filter and top-N cap are placed inside
                # this branch to mirror _escape_fts5, where the
                # all-stopword path is uncapped. Confirm against the
                # upstream file.
                #
                # Synonym merge — same dedup as the _escape_fts5 OR branch.
                if extra_or_tokens:
                    seen_lower = {t.lower() for t in or_pool}
                    for t in sorted(extra_or_tokens):
                        tl = t.lower()
                        if tl and tl not in seen_lower and tl not in _FTS5_STOPWORDS:
                            or_pool.append(tl)
                            seen_lower.add(tl)
                # DF filter. ~15ms warm per token; only fires on the rare
                # path where every progressive-AND chain returned zero.
                or_pool = _filter_or_pool_by_df(self.conn, or_pool)
                # Top-N-longest cap (proxy for rarity / topical specificity).
                if len(or_pool) > _OR_FALLBACK_MAX_TOKENS:
                    or_pool = sorted(or_pool, key=len, reverse=True)[
                        :_OR_FALLBACK_MAX_TOKENS
                    ]
            rows = self._match_rows(" OR ".join(_quote(t) for t in or_pool), limit)

        return [
            Hit(
                document_root=r["document_root"],
                document_uri=r["document_uri"],
                chunk_idx=r["idx"],
                snippet=_build_snippet(unpack_chunk(r["raw_content"]) or "", query),
                # bm25 returns negative numbers (lower = better); flip sign.
                score=-float(r["rank"]) if r["rank"] is not None else 0.0,
                audit_mode=self.audit_mode,
                title=r["title"],
            )
            for r in rows
        ]