"""MediaWiki 'cur' table SQL dump source.
Handles 2003-era SQL dumps in bz2 format (e.g. 20030516_cur_tablesql.bz2).
Yields one Document per non-redirect main-namespace article, with
[[wikilinks]] extracted as outbound edges.
Implements a stream parser for MySQL extended INSERT syntax. The cur table
schema for that era starts: cur_id, cur_namespace, cur_title, cur_text, ...
We rely on positional access for the first four columns.
"""
from __future__ import annotations
import bz2
import re
from pathlib import Path
from typing import IO, Iterator
from arborist.document import Document, Edge
from arborist.source import Source
# Match [[Target]], [[Target|display]], [[Target#anchor]] forms.
# Group 1 captures the link target only; an optional `#anchor` part and an
# optional `|display` part are matched but not captured. No part may contain
# `]` or a line break, so a match never spans lines.
WIKILINK_RE = re.compile(r"\[\[([^\]\|#\n\r]+)(?:#[^\]\|\n\r]*)?(?:\|[^\]\n\r]*)?\]\]")
# Namespace number used as the default `namespace` filter by the dump sources
# in this module (the "main" article namespace).
NAMESPACE_MAIN = 0
_MYSQL_ESCAPE_MAP = {
"n": "\n",
"r": "\r",
"t": "\t",
"0": "\0",
"\\": "\\",
"'": "'",
'"': '"',
"Z": "\x1a",
}
def _decode_mysql_string(s: str) -> str:
"""Decode MySQL-escaped string content (outer quotes already stripped).
Hot fast path: the vast majority of string fields contain no backslashes
at all. Returning the input unchanged for those skips the slow loop.
"""
if "\\" not in s:
return s
out: list[str] = []
i = 0
n = len(s)
# Cache the mapping locally for tight-loop dispatch.
while i < n:
# Jump forward to the next backslash with C-level find.
bs = s.find("\\", i)
if bs < 0:
out.append(s[i:])
break
if bs > i:
out.append(s[i:bs])
if bs + 1 >= n:
out.append("\\")
i = bs + 1
continue
nxt = s[bs + 1]
out.append(_MYSQL_ESCAPE_MAP.get(nxt, nxt))
i = bs + 2
return "".join(out)
def _find_string_end(payload: str, start: int) -> int:
"""Return position of the closing `'` for a string opened just before `start`.
Backslash escapes considered: even count of preceding backslashes => the
quote is a real terminator; odd count => the quote is escaped.
"""
n = len(payload)
i = start
while i < n:
pos = payload.find("'", i)
if pos < 0:
return -1
bs = 0
p = pos - 1
while p >= start and payload[p] == "\\":
bs += 1
p -= 1
if bs % 2 == 0:
return pos
i = pos + 1
return -1
def _split_values_tuples(payload: str) -> list[list[str | None]]:
"""Parse `(v1,v2,...),(...),...` into tuples of decoded strings or None.
String slicing + str.find based. No per-char list.append. Roughly 5-10x
faster than the char-by-char version on real Wikipedia dumps.
"""
rows: list[list[str | None]] = []
i = 0
n = len(payload)
while i < n:
# Find the next '(' starting a tuple.
open_paren = payload.find("(", i)
if open_paren < 0:
break
i = open_paren + 1
values: list[str | None] = []
while i < n:
# Skip leading whitespace inside the tuple.
while i < n and payload[i] in " \t":
i += 1
if i >= n:
break
if payload[i] == "'":
end = _find_string_end(payload, i + 1)
if end < 0:
return rows # malformed; bail
values.append(_decode_mysql_string(payload[i + 1 : end]))
i = end + 1
else:
# NULL or unquoted scalar; ends at , or )
end = i
while end < n and payload[end] not in ",)":
end += 1
v = payload[i:end].strip()
values.append(None if v.upper() == "NULL" else v)
i = end
while i < n and payload[i] in " \t":
i += 1
if i < n and payload[i] == ",":
i += 1
continue
if i < n and payload[i] == ")":
i += 1
break
break
rows.append(values)
return rows
def _iter_insert_statements(file_obj: IO[str]) -> Iterator[str]:
"""Yield complete SQL statements using buffered string find.
State persists across read() chunks via three indices:
yield_from position in `pending` where the next yielded statement starts
scan_pos position to resume the scanner from (do NOT restart at 0
across chunks — would re-find the already-processed open `'`)
in_string are we inside a single-quoted string
The previous version reset the scanner to position 0 on every chunk,
which silently flipped in_string=False as soon as the first `'` in the
new buffer was re-encountered (the title-field open quote of the
article we were already deep inside). Persisting scan_pos fixes it.
"""
pending = ""
in_string = False
yield_from = 0 # start of current statement, in `pending` coords
scan_pos = 0 # next position to scan, in `pending` coords
while True:
chunk = file_obj.read(1 << 19)
if not chunk:
break
pending += chunk
n = len(pending)
i = scan_pos
while i < n:
if in_string:
# Find string terminator (handle escape parity).
end = i
while True:
pos = pending.find("'", end)
if pos < 0:
i = n # need more data
break
bs = 0
# Don't walk back past the start of the current statement;
# chars before yield_from belong to a previous statement.
p = pos - 1
while p >= yield_from and pending[p] == "\\":
bs += 1
p -= 1
if bs % 2 == 0:
in_string = False
i = pos + 1
break
end = pos + 1
if in_string:
break
continue
sc = pending.find(";", i)
qt = pending.find("'", i)
if sc < 0 and qt < 0:
i = n
break
if qt < 0 or (sc >= 0 and sc < qt):
yield pending[yield_from : sc + 1]
yield_from = sc + 1
i = sc + 1
else:
in_string = True
i = qt + 1
scan_pos = i
# Release yielded bytes; remap our two indices into the new buffer.
if yield_from > 0:
pending = pending[yield_from:]
scan_pos -= yield_from
yield_from = 0
tail = pending[yield_from:]
if tail.strip():
yield tail
def _extract_wikilinks(text: str, base_uri: str) -> list[Edge]:
    """Return one "wikilink" Edge per distinct link target found in *text*.

    Targets are stripped; empty targets and any target containing `:`
    (image / file / category style links, or a leading-colon link) are
    ignored. The destination URI is `base_uri` plus the target with spaces
    replaced by underscores. Edges keep first-occurrence order.
    """
    stripped_targets = (m.group(1).strip() for m in WIKILINK_RE.finditer(text))
    # dict.fromkeys de-duplicates while preserving insertion order.
    unique_uris = dict.fromkeys(
        base_uri + target.replace(" ", "_")
        for target in stripped_targets
        if target and ":" not in target
    )
    return [Edge(edge_type="wikilink", dst_uri=uri) for uri in unique_uris]
# Source implementation shared by the `cur` and `old` table dumps.
class WikipediaSqlDump(Source):
    """Iterates a MediaWiki SQL table dump (cur or old), bz2 or plain.

    Both `cur` (current snapshot) and `old` (revision history) tables share
    the first four column positions: id, namespace, title, text. The `cur`
    table has `cur_is_redirect` at position 10 (we skip redirects); `old`
    has no redirect flag (every revision is real).

    Shard support: pass `shard=(rank, total)` and the source yields only
    rows whose 0-based index satisfies `index % total == rank`. Every shard
    still parses the full SQL stream; only the emitted share differs.

    Args:
        path: Dump file; opened with bz2 when the name ends in ".bz2".
        table: "cur" or "old"; anything else raises ValueError.
        namespace: Only rows whose namespace column equals this are emitted.
        base_uri: Prefix for document URIs and wikilink destinations.
        shard: Optional `(rank, total)` stride filter; must satisfy
            `0 <= rank < total`, otherwise ValueError is raised.
        start_id: Resume point; rows whose id is <= start_id are skipped.
        encoding: Text encoding used to read the dump.
    """

    # Locates the VALUES keyword without upper-casing the whole statement
    # (statements can be many megabytes; the keyword sits near the start).
    _VALUES_RE = re.compile(r"VALUES", re.IGNORECASE)

    def __init__(
        self,
        path: str | Path,
        *,
        table: str = "cur",
        namespace: int = NAMESPACE_MAIN,
        base_uri: str = "https://en.wikipedia.org/wiki/",
        shard: tuple[int, int] | None = None,
        start_id: int = 0,
        encoding: str = "latin-1",
    ):
        if table not in ("cur", "old"):
            raise ValueError("table must be 'cur' or 'old'")
        self.path = Path(path)
        self.table = table
        self.namespace = namespace
        self.base_uri = base_uri
        self.source_type = f"wikipedia_{table}"
        if shard is not None:
            rank, total = shard
            if not (0 <= rank < total) or total < 1:
                raise ValueError(
                    f"invalid shard {shard}: need 0 <= rank < total, total >= 1"
                )
        self.shard = shard
        # Resume support: skip rows whose id (cur_id or old_id) is <= start_id.
        self.start_id = start_id
        # High-water mark of row ids observed during iteration; callers can
        # read it back to persist a resume point.
        self.last_id: int = start_id
        # 2003-era MediaWiki dumps mix Latin-1 raw bytes (e.g., 0xE9 for é)
        # with HTML entities. Latin-1 decode is lossless on every byte and
        # gives the right code point for the raw-byte cases. Modern dumps
        # are UTF-8 — pass encoding="utf-8" for those.
        self.encoding = encoding

    def iter_documents(self) -> Iterator[Document]:
        """Stream the dump and yield a Document for every accepted row.

        Rows are filtered, in order, by shard stride, resume id, and the
        per-row checks in `_row_to_doc`. `self.last_id` is raised to the
        largest row id seen among rows that pass the first two filters.
        """
        opener = bz2.open if str(self.path).endswith(".bz2") else open
        rank, total = (0, 1) if self.shard is None else self.shard
        idx = 0
        with opener(self.path, "rt", encoding=self.encoding, errors="replace") as f:
            for stmt in _iter_insert_statements(f):
                payload = self._values_payload(stmt)
                if payload is None:
                    continue
                for row in _split_values_tuples(payload):
                    # Stride filter on the running row index: every shard
                    # counts the same rows but only emits its own share.
                    in_shard = idx % total == rank
                    idx += 1
                    if not in_shard:
                        continue
                    if not row:
                        # A statement cut off right after "(" produces an
                        # empty row; there is nothing to index or emit.
                        continue
                    row_id = self._parse_row_id(row)
                    # Resume fast-forward: rows at or below start_id are
                    # dropped before the costlier link extraction runs.
                    if self.start_id and row_id <= self.start_id:
                        continue
                    if row_id > self.last_id:
                        self.last_id = row_id
                    yield from self._row_to_doc(row)

    def _values_payload(self, stmt: str) -> str | None:
        """Return the VALUES payload of an INSERT into `self.table`, else None."""
        head = stmt.lstrip()
        if not head[: len("INSERT INTO")].upper().startswith("INSERT INTO"):
            return None
        match = self._VALUES_RE.search(head)
        if match is None:
            return None
        # Treat backticks as separators and split on any whitespace so that
        # both `INSERT INTO cur VALUES` and "INSERT INTO `cur`\nVALUES" match.
        clause_tokens = head[: match.start()].lower().replace("`", " ").split()
        if self.table not in clause_tokens:
            return None
        return head[match.end():].rstrip().rstrip(";").rstrip()

    @staticmethod
    def _parse_row_id(row: list[str | None]) -> int:
        """Return the integer id in column 0, or 0 when it is NULL or not numeric."""
        try:
            return int(row[0]) if row[0] is not None else 0
        except (ValueError, TypeError):
            return 0

    def _row_to_doc(self, row: list[str | None]) -> Iterator[Document]:
        """Yield at most one Document for *row*, or nothing if it is filtered.

        A row is dropped when it has fewer than four columns, its namespace
        is not an integer equal to `self.namespace`, its title or text is
        empty, or (cur table only) its redirect flag is a non-zero integer.
        """
        if len(row) < 4:
            return
        try:
            ns = int(row[1]) if row[1] is not None else None
        except (ValueError, TypeError):
            return
        if ns != self.namespace:
            return
        title = row[2] or ""
        text = row[3] or ""
        if not title or not text:
            return
        # Cur-only: skip rows flagged as redirects (col 10 in the 2003 schema).
        # A missing, NULL or non-numeric flag is treated as "not a redirect".
        if self.table == "cur" and len(row) > 10 and row[10] is not None:
            try:
                if int(row[10]):
                    return
            except (ValueError, TypeError):
                pass
        edges = _extract_wikilinks(text, self.base_uri)
        uri = self.base_uri + title.replace(" ", "_")
        extra: dict = {}
        # Surface the row id for resume / chronological ordering.
        if row[0] is not None:
            extra[f"{self.table}_id"] = row[0]
        # `old` rows also carry a timestamp at position 7.
        if self.table == "old" and len(row) > 7 and row[7]:
            extra["old_timestamp"] = row[7]
        yield Document(
            uri=uri,
            content=text,
            source_type=self.source_type,
            title=title,
            edges=edges,
            extra=extra,
        )
# Backward-compatible thin wrapper that pins table="cur".
class WikipediaCurDump(WikipediaSqlDump):
    """Iterates a MediaWiki 'cur' table SQL dump.

    Thin wrapper over WikipediaSqlDump with `table="cur"`. `namespace` and
    `base_uri` stay positional-or-keyword for existing callers; the
    remaining parent options are forwarded as keyword-only arguments with
    the same defaults as the parent.
    """

    def __init__(
        self,
        path: str | Path,
        namespace: int = NAMESPACE_MAIN,
        base_uri: str = "https://en.wikipedia.org/wiki/",
        *,
        shard: tuple[int, int] | None = None,
        start_id: int = 0,
        encoding: str = "latin-1",
    ):
        super().__init__(
            path,
            table="cur",
            namespace=namespace,
            base_uri=base_uri,
            shard=shard,
            start_id=start_id,
            encoding=encoding,
        )
# Thin wrapper that pins table="old".
class WikipediaOldDump(WikipediaSqlDump):
    """Iterates a MediaWiki 'old' (revision history) table SQL dump.

    Thin wrapper over WikipediaSqlDump with `table="old"`. `namespace` and
    `base_uri` stay positional-or-keyword for existing callers; the
    remaining parent options are forwarded as keyword-only arguments with
    the same defaults as the parent.
    """

    def __init__(
        self,
        path: str | Path,
        namespace: int = NAMESPACE_MAIN,
        base_uri: str = "https://en.wikipedia.org/wiki/",
        *,
        shard: tuple[int, int] | None = None,
        start_id: int = 0,
        encoding: str = "latin-1",
    ):
        super().__init__(
            path,
            table="old",
            namespace=namespace,
            base_uri=base_uri,
            shard=shard,
            start_id=start_id,
            encoding=encoding,
        )