Source code for arborist.sources.wikipedia

"""MediaWiki 'cur' table SQL dump source.

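Handles 2003-era MediaWiki SQL dumps of the `cur` (current revisions) or `old`
(revision history) table, bz2-compressed or plain (e.g.
20030516_cur_tablesql.bz2). Yields one Document per row in the selected
namespace (the main namespace by default) that has both a title and text, with
[[wikilinks]] extracted as outbound edges; `cur` rows flagged as redirects are
skipped.

Implements a stream parser for MySQL extended INSERT syntax. For that era both
table schemas start with id, namespace, title, text (e.g. cur_id,
cur_namespace, cur_title, cur_text, ...). We rely on positional access for
those four columns, plus the `cur` redirect flag and the `old` timestamp.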
"""

from __future__ import annotations

import bz2
import re
from pathlib import Path
from typing import IO, Iterator

from arborist.document import Document, Edge
from arborist.source import Source


# Match [[Target]], [[Target|display]], [[Target#anchor]] forms; group 1 is the
# raw target. The pattern is assembled from adjacent raw-string pieces (joined
# at compile time) so each part can carry its own comment.
WIKILINK_RE = re.compile(
    r"\[\["                  # opening "[["
    r"([^\]\|#\n\r]+)"       # group 1: target, up to "]", "|", "#" or a newline
    r"(?:#[^\]\|\n\r]*)?"    # optional "#anchor" (not captured)
    r"(?:\|[^\]\n\r]*)?"     # optional "|display text" (not captured)
    r"\]\]"                  # closing "]]"
)

# Namespace number of ordinary articles; the default row filter of the dump
# sources below.
NAMESPACE_MAIN = 0


_MYSQL_ESCAPE_MAP = {
    "n": "\n",
    "r": "\r",
    "t": "\t",
    "0": "\0",
    "\\": "\\",
    "'": "'",
    '"': '"',
    "Z": "\x1a",
}


def _decode_mysql_string(s: str) -> str:
    """Resolve MySQL backslash escapes in the body of a string literal.

    `s` is the text between the outer quotes. Each backslash and the character
    after it are replaced by `_MYSQL_ESCAPE_MAP[char]`, falling back to the
    bare character; a lone trailing backslash is kept as-is.

    Fast path: a string without any backslash (the vast majority of fields)
    is returned unchanged without entering the loop.
    """
    if "\\" not in s:
        return s
    table = _MYSQL_ESCAPE_MAP  # local alias for the loop below
    pieces: list[str] = []
    length = len(s)
    pos = 0
    while pos < length:
        # C-level search for the next backslash.
        slash = s.find("\\", pos)
        if slash == -1:
            pieces.append(s[pos:])
            break
        if slash > pos:
            pieces.append(s[pos:slash])
        if slash == length - 1:
            # Dangling backslash at the very end: nothing to escape.
            pieces.append("\\")
            break
        escaped = s[slash + 1]
        pieces.append(table.get(escaped, escaped))
        pos = slash + 2
    return "".join(pieces)


def _find_string_end(payload: str, start: int) -> int:
    """Return position of the closing `'` for a string opened just before `start`.

    Backslash escapes considered: even count of preceding backslashes => the
    quote is a real terminator; odd count => the quote is escaped.
    """
    n = len(payload)
    i = start
    while i < n:
        pos = payload.find("'", i)
        if pos < 0:
            return -1
        bs = 0
        p = pos - 1
        while p >= start and payload[p] == "\\":
            bs += 1
            p -= 1
        if bs % 2 == 0:
            return pos
        i = pos + 1
    return -1


def _split_values_tuples(payload: str) -> list[list[str | None]]:
    """Parse ``(v1,v2,...),(...),...`` into rows of decoded strings or None.

    Quoted values are unescaped with `_decode_mysql_string`; an unquoted
    ``NULL`` (any case) becomes None and any other unquoted scalar is kept as
    its stripped text. Spaces, tabs, CR and LF around values are ignored.

    Only tuples closed by ``)`` become rows. Parsing stops, returning the rows
    completed so far, at an unterminated string, at a tuple cut off by the end
    of `payload`, or when a value is followed by anything other than ``,`` or
    ``)``. Resynchronising after such garbage could pick up a ``(`` from
    inside string data and emit nonsense rows.

    String slicing + str.find based (no per-char list.append); per an earlier
    measurement, roughly 5-10x faster than a char-by-char parser on real
    Wikipedia dumps.
    """
    ws = " \t\r\n"
    rows: list[list[str | None]] = []
    i = 0
    n = len(payload)
    while i < n:
        # Find the next '(' starting a tuple.
        open_paren = payload.find("(", i)
        if open_paren < 0:
            break
        i = open_paren + 1
        values: list[str | None] = []
        while True:
            while i < n and payload[i] in ws:
                i += 1
            if i >= n:
                return rows  # tuple cut off before its next value
            if payload[i] == "'":
                end = _find_string_end(payload, i + 1)
                if end < 0:
                    return rows  # unterminated string
                values.append(_decode_mysql_string(payload[i + 1 : end]))
                i = end + 1
            else:
                # NULL or unquoted scalar; ends at "," or ")".
                end = i
                while end < n and payload[end] not in ",)":
                    end += 1
                scalar = payload[i:end].strip()
                values.append(None if scalar.upper() == "NULL" else scalar)
                i = end
            while i < n and payload[i] in ws:
                i += 1
            if i >= n:
                return rows  # tuple cut off before its closing ")"
            sep = payload[i]
            i += 1
            if sep == ")":
                rows.append(values)
                break
            if sep != ",":
                return rows  # malformed separator
    return rows


def _iter_insert_statements(file_obj: IO[str]) -> Iterator[str]:
    """Yield complete SQL statements using buffered string find.

    State persists across read() chunks via three indices:
      yield_from  position in `pending` where the next yielded statement starts
      scan_pos    position to resume the scanner from (do NOT restart at 0
                  across chunks — would re-find the already-processed open `'`)
      in_string   are we inside a single-quoted string

    The previous version reset the scanner to position 0 on every chunk,
    which silently flipped in_string=False as soon as the first `'` in the
    new buffer was re-encountered (the title-field open quote of the
    article we were already deep inside). Persisting scan_pos fixes it.
    """
    pending = ""
    in_string = False
    yield_from = 0  # start of current statement, in `pending` coords
    scan_pos = 0    # next position to scan, in `pending` coords

    while True:
        chunk = file_obj.read(1 << 19)
        if not chunk:
            break
        pending += chunk
        n = len(pending)

        i = scan_pos
        while i < n:
            if in_string:
                # Find string terminator (handle escape parity).
                end = i
                while True:
                    pos = pending.find("'", end)
                    if pos < 0:
                        i = n  # need more data
                        break
                    bs = 0
                    # Don't walk back past the start of the current statement;
                    # chars before yield_from belong to a previous statement.
                    p = pos - 1
                    while p >= yield_from and pending[p] == "\\":
                        bs += 1
                        p -= 1
                    if bs % 2 == 0:
                        in_string = False
                        i = pos + 1
                        break
                    end = pos + 1
                if in_string:
                    break
                continue
            sc = pending.find(";", i)
            qt = pending.find("'", i)
            if sc < 0 and qt < 0:
                i = n
                break
            if qt < 0 or (sc >= 0 and sc < qt):
                yield pending[yield_from : sc + 1]
                yield_from = sc + 1
                i = sc + 1
            else:
                in_string = True
                i = qt + 1

        scan_pos = i
        # Release yielded bytes; remap our two indices into the new buffer.
        if yield_from > 0:
            pending = pending[yield_from:]
            scan_pos -= yield_from
            yield_from = 0

    tail = pending[yield_from:]
    if tail.strip():
        yield tail


def _extract_wikilinks(text: str, base_uri: str) -> list[Edge]:
    seen: set[str] = set()
    edges: list[Edge] = []
    for m in WIKILINK_RE.finditer(text):
        target = m.group(1).strip()
        if not target or target.startswith(":"):
            continue
        # Skip image / file / category interlinks (they often start "Image:" etc.)
        if ":" in target:
            continue
        uri = base_uri + target.replace(" ", "_")
        if uri in seen:
            continue
        seen.add(uri)
        edges.append(Edge(edge_type="wikilink", dst_uri=uri))
    return edges


# Whitespace and SQL comments that may precede a statement's first keyword:
# "# ..." and "-- ..." line comments and "/* ... */" blocks. The pattern also
# matches the empty string, so .match(...).end() is the first token's offset.
_LEADING_SQL_NOISE_RE = re.compile(r"(?:\s|#[^\n]*|--[^\n]*|/\*.*?\*/)*", re.DOTALL)

# "INSERT INTO <table> [(<columns>)] VALUES", case-insensitive, with an
# optionally backtick-quoted table name captured in group 1.
_INSERT_HEAD_RE = re.compile(
    r"INSERT\s+INTO\s+`?(\w+)`?\s*(?:\([^)]*\)\s*)?VALUES",
    re.IGNORECASE,
)


def _int_or_none(value: str | None) -> int | None:
    """Parse a column value as int; None for SQL NULL or a non-integer value."""
    if value is None:
        return None
    try:
        return int(value)
    except (ValueError, TypeError):
        return None


class WikipediaSqlDump(Source):
    """Iterates a MediaWiki SQL table dump (cur or old), bz2 or plain.

    Both `cur` (current snapshot) and `old` (revision history) tables share the
    first four column positions: id, namespace, title, text. The `cur` table
    has `cur_is_redirect` at position 10 (we skip redirects); `old` has no
    redirect flag (every revision is real).

    Only ``INSERT INTO <table> [(columns)] VALUES ...`` statements for the
    selected table are parsed. The table name may be backtick-quoted, and SQL
    comments in front of a statement are ignored.

    Shard support: pass `shard=(rank, total)` and the source yields only docs
    whose 0-based index satisfies `index % total == rank`. Useful for spawning
    N parallel ingest processes against the same dump file — parser CPU runs
    in parallel, writes serialize at the WAL writer lock.
    """

    def __init__(
        self,
        path: str | Path,
        *,
        table: str = "cur",
        namespace: int = NAMESPACE_MAIN,
        base_uri: str = "https://en.wikipedia.org/wiki/",
        shard: tuple[int, int] | None = None,
        start_id: int = 0,
        encoding: str = "latin-1",
    ):
        """Configure the dump source.

        Args:
            path: Dump file; a name ending in ".bz2" is read through bz2.
            table: "cur" or "old" — which table's INSERTs to parse.
            namespace: Only rows whose namespace column equals this are emitted.
            base_uri: Prefix for document and wikilink URIs (spaces -> "_").
            shard: Optional (rank, total) stride filter; see class docstring.
            start_id: Skip rows whose id is <= this (0 disables the skip).
            encoding: Text encoding used to decode the dump.

        Raises:
            ValueError: unknown `table`, or `shard` outside 0 <= rank < total.
        """
        if table not in ("cur", "old"):
            raise ValueError("table must be 'cur' or 'old'")
        self.path = Path(path)
        self.table = table
        self.namespace = namespace
        self.base_uri = base_uri
        self.source_type = f"wikipedia_{table}"
        if shard is not None:
            rank, total = shard
            if not (0 <= rank < total) or total < 1:
                raise ValueError(
                    f"invalid shard {shard}: need 0 <= rank < total, total >= 1"
                )
        self.shard = shard
        # Resume support: skip rows whose id (cur_id or old_id) is <= start_id.
        # The rsync-style fast-forward — already-cached docs aren't re-hashed.
        self.start_id = start_id
        # High-water mark observed during iteration. Caller reads this back
        # after each batch and persists it via store.set_meta().
        self.last_id: int = start_id
        # 2003-era MediaWiki dumps mix Latin-1 raw bytes (e.g., 0xE9 for é)
        # with HTML entities. Latin-1 decode is lossless on every byte and
        # gives the right code point for the raw-byte cases. Modern dumps
        # are UTF-8 — pass encoding="utf-8" for those.
        self.encoding = encoding

    def iter_documents(self) -> Iterator[Document]:
        """Stream Documents for this shard's rows of the selected table.

        Rows are numbered (0-based) across every matching INSERT in file order;
        a row belongs to this shard when ``index % total == rank``. A shard
        row's id (column 0; 0 when missing or not an integer) is compared with
        ``start_id``: when ``start_id`` is non-zero, rows at or below it are
        skipped. Remaining rows raise ``self.last_id`` to their id before being
        handed to ``_row_to_doc``, so ``last_id`` also advances past rows that
        produce no Document.
        """
        opener = bz2.open if str(self.path).endswith(".bz2") else open
        rank, total = (0, 1) if self.shard is None else self.shard
        row_index = 0
        with opener(self.path, "rt", encoding=self.encoding, errors="replace") as f:
            for stmt in _iter_insert_statements(f):
                payload = self._values_payload(stmt)
                if payload is None:
                    continue
                for row in _split_values_tuples(payload):
                    # Stride filter applied at the row index — every shard sees
                    # the full SQL stream but only emits its share.
                    in_shard = row_index % total == rank
                    row_index += 1
                    if not in_shard:
                        continue
                    row_id = _int_or_none(row[0]) if row else None
                    if row_id is None:
                        row_id = 0
                    # Resume fast-forward: skip rows already cached. We still
                    # parse the value tuples (needed to find row[0]) but skip
                    # the wikilink extraction + Document construction, which
                    # is the dominant per-row cost.
                    if self.start_id and row_id <= self.start_id:
                        continue
                    if row_id > self.last_id:
                        self.last_id = row_id
                    yield from self._row_to_doc(row)

    def _values_payload(self, stmt: str) -> str | None:
        """Return the tuple list after VALUES if `stmt` inserts into our table.

        Returns None for every other statement (DDL, LOCK TABLES, inserts into
        other tables). Leading whitespace and comments are skipped first, so a
        comment block in front of an INSERT does not hide it. Matching is
        anchored at the statement start, so the (possibly multi-megabyte)
        statement is never uppercased or scanned as a whole.
        """
        noise = _LEADING_SQL_NOISE_RE.match(stmt)
        start = noise.end() if noise else 0
        head = _INSERT_HEAD_RE.match(stmt, start)
        if head is None or head.group(1).lower() != self.table:
            return None
        return stmt[head.end():].rstrip().rstrip(";").rstrip()

    def _row_to_doc(self, row: list[str | None]) -> Iterator[Document]:
        """Yield at most one Document for a parsed row.

        Nothing is yielded when the row has fewer than four columns, its
        namespace is unparseable or differs from ``self.namespace`` (NULL is
        compared as None), its title or text is empty, or, for ``cur``, its
        redirect flag at column 10 parses to a non-zero int.
        """
        if len(row) < 4:
            return
        try:
            ns = int(row[1]) if row[1] is not None else None
        except (ValueError, TypeError):
            return
        if ns != self.namespace:
            return
        title = row[2] or ""
        text = row[3] or ""
        if not title or not text:
            return
        # Cur-only: skip rows flagged as redirects (col 10 in the 2003 schema).
        if self.table == "cur" and len(row) > 10 and _int_or_none(row[10]):
            return
        edges = _extract_wikilinks(text, self.base_uri)
        uri = self.base_uri + title.replace(" ", "_")
        extra: dict[str, str] = {}
        # Surface the row id for resume / chronological ordering.
        if row[0] is not None:
            extra[f"{self.table}_id"] = row[0]
        # `old` rows also carry a timestamp at position 7.
        if self.table == "old" and len(row) > 7 and row[7]:
            extra["old_timestamp"] = row[7]
        yield Document(
            uri=uri,
            content=text,
            source_type=self.source_type,
            title=title,
            edges=edges,
            extra=extra,
        )
# Backward-compatible thin wrappers: each pins `table` to a single value.
class WikipediaCurDump(WikipediaSqlDump):
    """Iterates a MediaWiki 'cur' table SQL dump.

    Thin wrapper over WikipediaSqlDump with ``table="cur"`` fixed.
    ``namespace`` and ``base_uri`` stay positional for backward compatibility;
    ``shard``, ``start_id`` and ``encoding`` are forwarded as keyword-only
    options with the same defaults as WikipediaSqlDump.
    """

    def __init__(
        self,
        path: str | Path,
        namespace: int = NAMESPACE_MAIN,
        base_uri: str = "https://en.wikipedia.org/wiki/",
        *,
        shard: tuple[int, int] | None = None,
        start_id: int = 0,
        encoding: str = "latin-1",
    ):
        super().__init__(
            path,
            table="cur",
            namespace=namespace,
            base_uri=base_uri,
            shard=shard,
            start_id=start_id,
            encoding=encoding,
        )
class WikipediaOldDump(WikipediaSqlDump):
    """Iterates a MediaWiki 'old' (revision history) table SQL dump.

    Thin wrapper over WikipediaSqlDump with ``table="old"`` fixed.
    ``namespace`` and ``base_uri`` stay positional for backward compatibility;
    ``shard``, ``start_id`` and ``encoding`` are forwarded as keyword-only
    options with the same defaults as WikipediaSqlDump.
    """

    def __init__(
        self,
        path: str | Path,
        namespace: int = NAMESPACE_MAIN,
        base_uri: str = "https://en.wikipedia.org/wiki/",
        *,
        shard: tuple[int, int] | None = None,
        start_id: int = 0,
        encoding: str = "latin-1",
    ):
        super().__init__(
            path,
            table="old",
            namespace=namespace,
            base_uri=base_uri,
            shard=shard,
            start_id=start_id,
            encoding=encoding,
        )