diff --git a/docs/internal/proposals/inflight-registry.md b/docs/internal/proposals/inflight-registry.md new file mode 100644 index 0000000..1982637 --- /dev/null +++ b/docs/internal/proposals/inflight-registry.md @@ -0,0 +1,301 @@ +# SQLite-Backed Inflight Item Registry + +## Problem + +The previous `JobChecker` (`exca/map.py`) had a documented TOCTOU race +(`docs/internal/debug/concurrent-writes.md`) that caused 17.5x data duplication in +production. It tracked jobs, not items, and provided no visibility into what each job +was processing. + +The goal is not perfection — it's to prevent the obvious duplicate-submission stampede +(~90% of issues) without introducing bugs. The registry is **advisory, not authoritative**: +the `CacheDict` remains the source of truth. If the registry is corrupted or unavailable, +we fall back to current behavior (no coordination). + +## Chosen Approach: SQLite with Graceful Fallback + +SQLite (`journal_mode=DELETE` for NFS safety) stored in the cache folder. stdlib +dependency, atomic transactions, queryable for debugging. + +**NFS risk mitigation**: SQLite on NFS can theoretically corrupt the DB under broken +`fcntl()` locks. 
But since the DB is advisory: + +- Corruption → delete and fall back to no coordination (same as today) +- Lost lock → two workers claim the same item → duplicate work (same as today, but rare) +- The actual cached data in `CacheDict` is never at risk + +## Schema + +Single table, one row per in-flight item: + +```sql +CREATE TABLE IF NOT EXISTS inflight ( + item_uid TEXT PRIMARY KEY, + pid INTEGER NOT NULL, + job_id TEXT, + job_folder TEXT, + claimed_at REAL NOT NULL +); +``` + +- `item_uid` — the cache key being processed +- `pid` — OS pid of the claiming process (always available; used for liveness + fallback before job info is recorded) +- `job_id` — nullable: Slurm job ID string, set via `update_worker_info()` after + submission for actual Slurm jobs only (`isinstance(job, submitit.SlurmJob)`) +- `job_folder` — nullable: submitit executor folder path, set alongside `job_id` +- `claimed_at` — `time.time()` timestamp, for debugging / last-resort stale detection + +## WorkerInfo + +Worker identity is represented by a frozen dataclass: + +```python +@dataclasses.dataclass(frozen=True) +class WorkerInfo: + pid: int + job_id: str | None = None + job_folder: str | None = None + claimed_at: float | None = None + + def is_alive(self) -> bool: ... + def wait(self) -> None: ... +``` + +- `is_alive()` — Slurm path if `job_id` is set (reconstructs `SlurmJob`, checks + `.done()`), otherwise PID check via `os.kill(pid, 0)`. +- `wait()` — blocking wait for Slurm jobs; no-op for local workers. + +Frozen so it can be used as a dict key for grouping liveness checks in `claim()`. + +## Liveness Checks + +Two strategies based on what's available: + +- **Slurm** (`job_id IS NOT NULL`): reconstruct `submitit.SlurmJob(job_id=job_id, + folder=job_folder)`, call `.done()`. submitit handles `sacct` throttling internally. + Only recorded for actual Slurm jobs — `DebugExecutor` and `LocalExecutor` jobs + are excluded to prevent incorrect Slurm liveness checks on non-Slurm jobs. 
+- **Local / pools** (`job_id IS NULL`): `os.kill(pid, 0)` — signal 0 checks process + existence without killing it. Works for submitit `LocalExecutor` (subprocess), + `ProcessPoolExecutor`, and `ThreadPoolExecutor` (all same-host, PID = parent process). + +## Core Flow + +All callers use the `inflight_session()` context manager, which encapsulates the +full lifecycle: + +``` + ┌──────────────────────────────┐ + │ inflight_session() │ + │ ┌─ retry loop ────────────┐ │ + │ │ 1. wait_for_inflight() │ │ + │ │ ├─ Slurm: .wait() │ │ + │ │ ├─ Local: poll │ │ + │ │ │ (0.5s→30s backoff)│ │ + │ │ └─ Reclaim dead │ │ + │ │ 2. claim() all-or-none │ │ + │ │ ├─ All claimable │ │ + │ │ │ → COMMIT, break │ │ + │ │ └─ Live blocker │ │ + │ │ → ROLLBACK, retry │ │ + │ └─────────────────────────┘ │ + │ 3. yield all claimed_uids │ + └──────────┬───────────────────┘ + │ + ▼ + ┌──────────────────────────────┐ + │ Caller submits work │ + │ 1. Re-check cache (refresh │ + │ after wait avoids re- │ + │ submitting done items) │ + │ 2. executor.submit() │ + │ 3. update_worker_info() │ + │ (Slurm jobs only — │ + │ records job_id + folder │ + │ for Slurm liveness) │ + │ 4. job.result() / wait │ + └──────────┬───────────────────┘ + │ + ▼ + ┌──────────────────────────────┐ + │ finally (session exit) │ + │ 1. release(claimed_uids) │ + │ 2. close() │ + └──────────────────────────────┘ +``` + +**Self-deadlock prevention**: `wait_for_inflight()` and `claim()` both skip items +owned by the current PID, so re-entrant / nested calls never block on themselves. + +**Wait behavior**: callers need all items complete before returning. For items claimed +by another live worker: + +- Slurm: reconstruct the Job and call `.wait()` +- Local/pools: poll with exponential backoff (0.5 s → 30 s cap), checking liveness + each iteration. Jitter (0–0.5 s) before the first poll de-synchronizes concurrent + callers. First INFO log after 60 s, then hourly, with item count and PIDs. 
+ +## InflightRegistry API + +Located in `exca/cachedict/inflight.py`. + +```python +class InflightRegistry: + """Advisory SQLite registry of in-flight cache items.""" + + def __init__(self, cache_folder: Path, permissions: int | None = 0o777) -> None: + # DB at /inflight.db + ... + + def claim(self, item_uids: list[str], + pid: int | None = None) -> list[str]: + """All-or-nothing claim: COMMIT if all items claimable, ROLLBACK + otherwise. Returns all item_uids on success, or only pre-owned + items (same PID) on rollback. pid defaults to os.getpid().""" + + def update_worker_info(self, item_uids: list[str], *, + job_id: str | None = None, + job_folder: str | None = None) -> None: + """Update job_id/job_folder for already-claimed items. + Called after submission, when the Slurm job ID is known.""" + + def release(self, item_uids: list[str]) -> None: + """Remove items from the registry (done or failed).""" + + def get_inflight(self, item_uids: list[str] | None = None) -> dict[str, WorkerInfo]: + """Return claimed items with their worker info.""" + + def wait_for_inflight(self, item_uids: list[str]) -> list[str]: + """Block until items are no longer in-flight. + Reclaims items from dead workers and returns their uids.""" + + def close(self) -> None: ... +``` + +### inflight_session() context manager + +```python +@contextlib.contextmanager +def inflight_session( + registry: InflightRegistry | None, + item_uids: list[str], +) -> tp.Iterator[list[str]]: + """Wait → claim → yield claimed → release + close. + When registry is None, yields all item_uids (no-op). + The registry connection is closed on exit; callers must call + update_worker_info() inside the with block.""" +``` + +## Integration Points + +### MapInfra (submitit path: slurm / local / auto) + +- `_method_override()`: wraps the submit+wait block in `inflight_session`. +- After the session wait, re-checks `cache_dict` to skip items completed by others + during the wait (avoids needless re-submission). 
+- After `executor.submit()`: `registry.update_worker_info()` per chunk, but only + for actual Slurm jobs (`isinstance(j, submitit.SlurmJob)`). +- Release happens in `inflight_session`'s finally block. + +### MapInfra (pool path: threadpool / processpool) + +- `_method_override_futures()`: same `inflight_session` pattern with cache refresh. +- Uses `pid=os.getpid()` (default), no `job_id` (all same-host). + +### Steps Backend + +- `Backend.run()`: wraps compute in `inflight_session` with single-item granularity. +- Registry is only created for non-inline backends (`type(self) is not Cached`). +- After `_submit()`: `registry.update_worker_info()` for Slurm jobs only + (`isinstance(job, submitit.SlurmJob)`). + +## All-or-Nothing Claim + +`claim()` uses all-or-nothing semantics enforced at the database level: + +1. **Phase 1 (outside transaction)**: Read existing claims, perform liveness checks + grouped by `(pid, job_id, job_folder)` so that a single dead Slurm job with many + items triggers only one `sacct` round-trip. +2. **Phase 2 (`BEGIN IMMEDIATE` transaction)**: Re-read ownership, apply cached + liveness verdicts. If every item is claimable (free, dead, or already ours) → + INSERT OR REPLACE + `COMMIT`. If any item is held by a live worker → + `ROLLBACK` — nothing is written, no partial claims are visible. + +This prevents hold-and-wait deadlocks: a worker never holds a subset of items while +blocking on the rest. The `inflight_session` retry loop simply waits and re-tries +`claim()` until it succeeds — no release-on-retry needed because `ROLLBACK` ensures +nothing was written. + +On rollback, `claim()` returns only items already owned by the caller's PID +(re-entrant / nested calls). On commit, it returns all requested items. + +On NFS with broken locking, this degrades to the same duplicate-work behavior as +before (accepted failure mode). 
+ +## Contention Hardening + +Designed for 100+ simultaneous callers (e.g., Slurm array jobs) hitting the same DB: + +- **Busy timeout: 60 s** — SQLite retries lock acquisition internally. Zero overhead + when uncontested; only blocks when another writer holds the lock. 60 s accommodates + NFS lock latency (10–100 ms per operation) × 100+ callers. +- **Transient retry with backoff**: `_safe_execute()` distinguishes between transient + lock errors (`sqlite3.OperationalError` with "locked" / "busy") and permanent errors + (corruption, schema issues). Transient errors are retried up to 3 times with random + backoff (0.5–2 s × attempt). Permanent errors trigger graceful degradation. This + prevents the failure mode where a lock timeout causes `claim()` to return all items + as "claimed" (the degradation fallback), leading to duplicate submissions. +- **Wait jitter**: `wait_for_inflight()` adds 0–0.5 s random sleep when items are + inflight, de-synchronizing callers that start simultaneously. + +## Graceful Degradation + +Every registry method wraps DB access via `_safe_execute()`: + +- Transient lock errors → retry with random backoff (up to 3 attempts) +- Permanent errors → log warning, `_try_reset()` (close + delete DB for auto-recovery), + return fallback value + +If the DB file is corrupt, `_try_reset()` deletes it so the next access recreates +a fresh DB. This ensures the registry never blocks or breaks actual computation. + +## Scope + +One `inflight.db` per **cache folder** (not per executor folder). This is the correct +scope because different experiments sharing the same cache folder is exactly the case +where coordination matters — which is what `docs/internal/debug/concurrent-writes.md` +identified as the core problem. + +The DB file is visible (no leading dot) for easy manual deletion if needed. File +permissions default to `0o777` (matching CacheDict's shared-access model) and are +applied after DB creation. 
+ +## Same-PID Ownership + +The registry uses `os.getpid()` as the ownership identity. `claim()` treats +same-PID items as already owned (re-entrant), and `wait_for_inflight()` skips +same-PID items to prevent self-deadlock. + +### Nested release protection + +Inner `inflight_session` calls for the same items (e.g., chains, nested Steps) +must not release the outer session's claims. This is handled by tracking which +items were already owned at session entry: the `finally` block only releases +items the session actually inserted, not items inherited from an outer session. + +### Remaining limitation: PID is too broad for concurrent ownership + +PID-based identity is correct for true re-entrant calls within one logical call +stack, but too broad for independent concurrent work in the same process: + +- **Thread pools**: Workers in a `ThreadPoolExecutor` share the parent PID, so + multiple threads can all consider the same item "theirs." +- **Overlapping MapInfra instances**: Two models writing to the same cache folder + in one process share PID-based ownership. + +The consequence is duplicate work (not data corruption) — CacheDict is the source +of truth. Fixing this properly requires an `owner_token` column separate from PID, +so that ownership is narrowed to the exact session, not the whole process. This is +deferred pending evidence of significant duplicate work caused by this limitation +in practice. diff --git a/exca/cachedict/inflight.py b/exca/cachedict/inflight.py new file mode 100644 index 0000000..3daa5fd --- /dev/null +++ b/exca/cachedict/inflight.py @@ -0,0 +1,503 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Advisory SQLite registry of in-flight cache items. + +Tracks which items are being processed by which worker, enabling +concurrent processes to avoid duplicate submissions. 
The registry
+is advisory — CacheDict remains the source of truth. If the DB is
+corrupt or inaccessible, all methods degrade gracefully (log a
+warning and behave as if the registry is empty).
+"""
+
+import contextlib
+import dataclasses
+import functools
+import logging
+import os
+import random
+import shutil
+import sqlite3
+import time
+import typing as tp
+from pathlib import Path
+
+import submitit
+
+logger = logging.getLogger(__name__)
+
+_SCHEMA = """\
+CREATE TABLE IF NOT EXISTS inflight (
+    item_uid TEXT PRIMARY KEY,
+    pid INTEGER NOT NULL,
+    job_id TEXT,
+    job_folder TEXT,
+    claimed_at REAL NOT NULL
+);
+"""
+
+T = tp.TypeVar("T")
+
+
+# -- Helpers ------------------------------------------------------------------
+
+
+@functools.lru_cache(maxsize=1)
+def _has_sacct() -> bool:
+    """Check whether sacct is available (cached after first call).
+
+    Secondary safety net: on machines without sacct (dev, CI), submitit's
+    SlurmJob.done() silently returns False instead of raising, making dead
+    jobs appear alive and causing wait_for_inflight to hang. The primary
+    defense is the isinstance(job, SlurmJob) gate in callers that prevents
+    non-Slurm job info from being recorded in the first place.
+    """
+    return shutil.which("sacct") is not None
+
+
+def _is_pid_alive(pid: int) -> bool:
+    # Signal 0 probes process existence without sending a signal.
+    try:
+        os.kill(pid, 0)
+        return True
+    except ProcessLookupError:
+        return False
+    except PermissionError:
+        # Process exists but belongs to another user.
+        return True
+
+
+# -- Worker identity ----------------------------------------------------------
+
+
+@dataclasses.dataclass(frozen=True)
+class WorkerInfo:
+    """Identity of the worker that claimed an item.
+
+    Also serves as the DB row representation when ``claimed_at`` is set.
+    Frozen so it can be used as a dict key for grouping liveness checks.
+    """
+
+    pid: int
+    job_id: str | None = None
+    job_folder: str | None = None
+    claimed_at: float | None = None
+
+    def __post_init__(self) -> None:
+        # Pre-register with submitit's shared SlurmInfoWatcher so that
+        # batch sacct calls cover all workers created in the same
+        # get_inflight() result set (one sacct call instead of N).
+        job: tp.Any = None
+        if self.job_id is not None and self.job_folder is not None and _has_sacct():
+            try:
+                job = submitit.SlurmJob(job_id=self.job_id, folder=self.job_folder)
+            except Exception:
+                pass
+        # object.__setattr__ needed because the dataclass is frozen.
+        object.__setattr__(self, "_job", job)
+
+    @classmethod
+    def _from_row(
+        cls, row: tuple[str, int, str | None, str | None, float]
+    ) -> tuple[str, "WorkerInfo"]:
+        """Convert a (item_uid, pid, job_id, job_folder, claimed_at) row."""
+        uid, pid, job_id, job_folder, claimed_at = row
+        return uid, cls(
+            pid=pid, job_id=job_id, job_folder=job_folder, claimed_at=claimed_at
+        )
+
+    def is_alive(self) -> bool:
+        """Check if this worker is still running."""
+        job: submitit.SlurmJob | None = self._job  # type: ignore[attr-defined]
+        if job is not None:
+            try:
+                return not job.done()
+            except Exception:
+                # A job we cannot query is treated as dead (reclaimable).
+                return False
+        return _is_pid_alive(self.pid)
+
+    def wait(self) -> None:
+        """Block until a Slurm job finishes (no-op for local workers)."""
+        job: submitit.SlurmJob | None = self._job  # type: ignore[attr-defined]
+        if job is None:
+            return
+        try:
+            if not job.done():
+                job.wait()
+        except Exception:
+            logger.debug("Could not wait for Slurm job %s", self.job_id, exc_info=True)
+
+
+# -- Registry -----------------------------------------------------------------
+
+
+class InflightRegistry:
+    """Advisory SQLite registry of in-flight cache items.
+
+    All public methods gracefully degrade: if the DB is corrupt or
+    inaccessible, they log a warning and behave as if the registry
+    is empty.
+
+    Parameters
+    ----------
+    cache_folder:
+        Path to the cache folder. The DB is stored as
+        ``/inflight.db``.
+    permissions:
+        File permissions applied to the DB file after creation
+        (mirrors CacheDict's permission handling). ``None`` to skip.
+    """
+
+    def __init__(self, cache_folder: Path | str, permissions: int | None = 0o777) -> None:
+        self.db_path = Path(cache_folder) / "inflight.db"
+        self.permissions = permissions
+        self._conn: sqlite3.Connection | None = None
+
+    # -- Connection management ------------------------------------------------
+
+    def _connect(self) -> sqlite3.Connection:
+        """Lazy-open the DB connection, creating the table if needed."""
+        if self._conn is not None:
+            return self._conn
+        self.db_path.parent.mkdir(parents=True, exist_ok=True)
+        # Autocommit (isolation_level=None): most methods rely on implicit
+        # per-statement transactions; claim() uses explicit BEGIN IMMEDIATE
+        # / COMMIT to serialize concurrent claims.
+        conn = sqlite3.connect(
+            str(self.db_path),
+            timeout=20,
+            isolation_level=None,
+        )
+        conn.execute("PRAGMA journal_mode=DELETE")
+        conn.execute(_SCHEMA)
+        if self.permissions is not None:
+            try:
+                self.db_path.chmod(self.permissions)
+            except Exception:
+                msg = "Failed to set permissions on %s"
+                logger.warning(msg, self.db_path, exc_info=True)
+        self._conn = conn
+        return conn
+
+    def _safe_connect(self) -> sqlite3.Connection | None:
+        """Connect with graceful fallback — returns None on failure."""
+        try:
+            return self._connect()
+        except Exception:
+            msg = "Inflight registry unavailable at %s, proceeding without coordination"
+            logger.warning(msg, self.db_path, exc_info=True)
+            self._try_reset()
+            return None
+
+    def _try_reset(self) -> None:
+        """Close connection and delete corrupt DB so next access recreates it."""
+        self._close_conn()
+        try:
+            self.db_path.unlink(missing_ok=True)
+        except Exception:
+            pass
+
+    def _close_conn(self) -> None:
+        if self._conn is not None:
+            try:
+                self._conn.close()
+            except Exception:
+                pass
+            self._conn = None
+
+    def close(self) -> None:
+        """Close the DB connection."""
+        self._close_conn()
+
+    def _safe_execute(
+        self, op_name: str, fallback: T, fn: tp.Callable[[sqlite3.Connection], T]
+    ) -> T:
+        """Run *fn* against the DB connection with graceful degradation.
+
+        Transient lock errors (``sqlite3.OperationalError`` with "locked"
+        or "busy") are retried with random backoff. Other errors trigger
+        graceful degradation (log + return fallback).
+        """
+        conn = self._safe_connect()
+        if conn is None:
+            return fallback
+        for attempt in range(3):
+            try:
+                return fn(conn)
+            except sqlite3.OperationalError as e:
+                if "locked" not in str(e) and "busy" not in str(e):
+                    break
+                # Rollback any aborted transaction before retrying.
+                try:
+                    conn.execute("ROLLBACK")
+                except Exception:
+                    pass
+                if attempt < 2:
+                    delay = random.uniform(0, attempt + 1)
+                    msg = "Inflight registry %s: lock contention, retry %d in %.1fs"
+                    logger.debug(msg, op_name, attempt + 1, delay)
+                    time.sleep(delay)
+                    continue
+                break
+            except Exception:
+                break
+        logger.warning("Inflight registry %s failed", op_name, exc_info=True)
+        self._try_reset()
+        return fallback
+
+    # -- Core operations ------------------------------------------------------
+
+    def claim(
+        self,
+        item_uids: list[str],
+        pid: int | None = None,
+    ) -> list[str]:
+        """Atomically claim all requested items, or none (except pre-owned).
+
+        All-or-nothing semantics enforced at the database level via
+        ROLLBACK: if any item is held by a live worker with a different
+        PID, the entire transaction is rolled back and no new claims are
+        written. This prevents partial-claim hold-and-wait deadlocks
+        across concurrent sessions with overlapping item sets.
+
+        Returns the list of item_uids actually claimed. On success this
+        equals *item_uids*. On rollback it contains only items already
+        owned by *pid* (re-entrant / nested calls).
+        """
+        if not item_uids:
+            return []
+        if pid is None:
+            pid = os.getpid()
+
+        # Phase 1: liveness checks outside the transaction (can be slow
+        # for Slurm sacct calls — must not hold the DB write lock).
+        existing = self.get_inflight(item_uids)
+        alive_cache: dict[WorkerInfo, bool] = {}
+        for info in existing.values():
+            if info.pid != pid and info not in alive_cache:
+                alive_cache[info] = info.is_alive()
+
+        # Phase 2: short transaction — only SELECT + INSERT, no I/O.
+        # All-or-nothing: COMMIT if every item is claimable, ROLLBACK
+        # otherwise. This guarantees no partial claims are visible to
+        # other workers.
+        def _do(conn: sqlite3.Connection) -> list[str]:
+            now = time.time()
+            conn.execute("BEGIN IMMEDIATE")
+            placeholders = ",".join("?" for _ in item_uids)
+            rows = conn.execute(
+                f"SELECT item_uid, pid, job_id, job_folder, claimed_at FROM inflight "
+                f"WHERE item_uid IN ({placeholders})",
+                item_uids,
+            ).fetchall()
+            fresh = dict(WorkerInfo._from_row(r) for r in rows)
+            pre_owned: list[str] = []
+            to_insert: list[str] = []
+            for uid in item_uids:
+                if uid in fresh:
+                    owner = fresh[uid]
+                    if owner.pid == pid:
+                        pre_owned.append(uid)
+                        continue
+                    if alive_cache.get(owner, True):
+                        # Live worker blocks us — rollback everything.
+                        conn.execute("ROLLBACK")
+                        return pre_owned
+                to_insert.append(uid)
+            conn.executemany(
+                "INSERT OR REPLACE INTO inflight "
+                "(item_uid, pid, job_id, job_folder, claimed_at) "
+                "VALUES (?, ?, NULL, NULL, ?)",
+                [(uid, pid, now) for uid in to_insert],
+            )
+            conn.execute("COMMIT")
+            return pre_owned + to_insert
+
+        result = self._safe_execute("claim", list(item_uids), _do)
+        logger.debug("Claimed %d/%d items (pid=%d)", len(result), len(item_uids), pid)
+        return result
+
+    def update_worker_info(
+        self,
+        item_uids: list[str],
+        *,
+        job_id: str | None = None,
+        job_folder: str | None = None,
+    ) -> None:
+        """Update job_id and job_folder for items already claimed.
+
+        Called after job submission, when the Slurm job ID becomes known.
+        Between claim and update, liveness falls back to PID check.
+        """
+        if not item_uids:
+            return
+
+        def _do(conn: sqlite3.Connection) -> None:
+            conn.executemany(
+                "UPDATE inflight SET job_id = ?, job_folder = ? WHERE item_uid = ?",
+                [(job_id, job_folder, uid) for uid in item_uids],
+            )
+
+        self._safe_execute("update", None, _do)
+        msg = "Updated worker info for %d items (job_id=%s)"
+        logger.debug(msg, len(item_uids), job_id)
+
+    def release(self, item_uids: list[str]) -> None:
+        """Remove items from the registry (done or failed)."""
+        if not item_uids:
+            return
+
+        def _do(conn: sqlite3.Connection) -> None:
+            conn.executemany(
+                "DELETE FROM inflight WHERE item_uid = ?",
+                [(uid,) for uid in item_uids],
+            )
+
+        self._safe_execute("release", None, _do)
+        logger.debug("Released %d items", len(item_uids))
+
+    def get_inflight(self, item_uids: list[str] | None = None) -> dict[str, WorkerInfo]:
+        """Return claimed items with their worker info."""
+
+        def _do(conn: sqlite3.Connection) -> dict[str, WorkerInfo]:
+            query = "SELECT item_uid, pid, job_id, job_folder, claimed_at FROM inflight"
+            if item_uids is None:
+                rows = conn.execute(query).fetchall()
+            elif not item_uids:
+                return {}
+            else:
+                placeholders = ",".join("?" for _ in item_uids)
+                rows = conn.execute(
+                    f"{query} WHERE item_uid IN ({placeholders})", item_uids
+                ).fetchall()
+            return dict(WorkerInfo._from_row(r) for r in rows)
+
+        return self._safe_execute("query", {}, _do)
+
+    def wait_for_inflight(
+        self,
+        item_uids: list[str],
+    ) -> list[str]:
+        """Block until the given items are no longer in-flight.
+
+        For Slurm items, waits via submitit. For local items, polls with
+        exponential backoff (0.5 s → 30 s) until the item disappears from
+        the registry or the owning process dies.
+
+        Items owned by the current process (``os.getpid()``) are silently
+        skipped to prevent self-deadlock in re-entrant / nested calls.
+
+        Returns the list of item_uids that were reclaimed from dead workers
+        (caller should recompute these).
+ """ + if not item_uids: + return [] + remaining = set(item_uids) + reclaimed: list[str] = [] + my_pid = os.getpid() + + inflight = self.get_inflight(list(remaining)) + if inflight: + # Jitter to de-synchronize callers that start simultaneously + # (e.g. Slurm array jobs), reducing claim contention. + time.sleep(random.uniform(0, 0.5)) + msg = "Waiting for %d in-flight items (of %d requested)" + logger.warning(msg, len(inflight), len(item_uids)) + for uid, info in list(inflight.items()): + if info.pid == my_pid: + remaining.discard(uid) + continue + if info.job_id is not None and info.job_folder is not None: + logger.debug("Waiting for Slurm job %s (item %s)", info.job_id, uid) + info.wait() + + interval = 0.5 + next_log = time.time() + 3600.0 + while remaining: + inflight = self.get_inflight(list(remaining)) + alive_cache: dict[WorkerInfo, bool] = {} + still_waiting: set[str] = set() + dead_uids: list[str] = [] + for uid in remaining: + if uid not in inflight: + continue + info = inflight[uid] + if info.pid == my_pid: + continue + if info not in alive_cache: + alive_cache[info] = info.is_alive() + if not alive_cache[info]: + msg = "Reclaiming item %s from dead worker (pid=%d)" + logger.debug(msg, uid, info.pid) + dead_uids.append(uid) + else: + still_waiting.add(uid) + if dead_uids: + self.release(dead_uids) + reclaimed.extend(dead_uids) + remaining = still_waiting + if remaining: + now = time.time() + if now >= next_log: + pids = {inflight[u].pid for u in remaining if u in inflight} + msg = "Still waiting for %d in-flight items (pids: %s)" + logger.info(msg, len(remaining), pids) + next_log = now + 3600.0 + time.sleep(interval) + interval = min(interval * 2, 30.0) + + return reclaimed + + +# -- Public context manager --------------------------------------------------- + + +@contextlib.contextmanager +def inflight_session( + registry: InflightRegistry | None, + item_uids: list[str], +) -> tp.Iterator[list[str]]: + """Wait for in-flight items, claim available 
ones, release+close on exit. + + When *registry* is ``None`` (no cache folder), yields all *item_uids* + unchanged so that callers never need a ``None`` guard. + + Self-deadlock is prevented internally: ``wait_for_inflight`` skips items + owned by the current PID, and ``claim`` treats same-PID rows as already + ours. + + The registry connection is closed on exit; callers must perform any + ``update_worker_info`` calls inside the ``with`` block. + """ + if registry is None: + yield list(item_uids) + return + pid = os.getpid() + # Track items already owned by this PID before we start, so that + # the finally block only releases items this session actually inserted + # (not items inherited from an outer / re-entrant session). + pre_owned: set[str] = { + uid for uid, info in registry.get_inflight(item_uids).items() if info.pid == pid + } + # Retry loop: wait for inflight items, then claim. claim() uses + # all-or-nothing semantics (ROLLBACK if any item is held by a live + # worker), so no partial claims are ever written — no release needed + # on retry, and no hold-and-wait deadlock is possible. + while True: + reclaimed = registry.wait_for_inflight(item_uids) + if reclaimed: + logger.info("Reclaimed %d items from dead workers", len(reclaimed)) + claimed = registry.claim(item_uids, pid=pid) + if len(claimed) == len(item_uids): + break + # claim() rolled back — some items held by live workers that + # appeared between wait_for_inflight and claim (lost-claim race). 
+ msg = "Claim race: got %d/%d items, re-waiting" + logger.info(msg, len(claimed), len(item_uids)) + time.sleep(random.uniform(0.5, 2.0)) + try: + yield claimed + finally: + to_release = [uid for uid in claimed if uid not in pre_owned] + registry.release(to_release) + registry.close() diff --git a/exca/cachedict/test_inflight.py b/exca/cachedict/test_inflight.py new file mode 100644 index 0000000..84f64ca --- /dev/null +++ b/exca/cachedict/test_inflight.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import stat +from pathlib import Path + +import pytest + +from .inflight import InflightRegistry, WorkerInfo, inflight_session + + +def test_registry_operations(tmp_path: Path) -> None: + reg = InflightRegistry(tmp_path) + pid = os.getpid() + dead_pid = 2**20 + 7 + + # Claim, query, re-entrant claim + claimed = reg.claim(["a", "b", "c"], pid=pid) + assert set(claimed) == {"a", "b", "c"} + assert set(reg.get_inflight(["a", "b", "c"])) == {"a", "b", "c"} + assert set(reg.claim(["a", "b", "c"], pid=pid)) == {"a", "b", "c"} + + # Update worker info (post-submission update) + reg.update_worker_info(["a", "b"], job_id="12345", job_folder="/logs") + info = reg.get_inflight(["a", "b"]) + assert info["a"].job_id == "12345" and info["b"].job_folder == "/logs" + + # Release subset, verify remainder + reg.release(["a", "b"]) + assert list(reg.get_inflight(["a", "b", "c"])) == ["c"] + reg.release(["c"]) + assert reg.get_inflight(["a", "b", "c"]) == {} + + # Dead worker reclaim via claim() + reg.claim(["x"], pid=dead_pid) + claimed = reg.claim(["x"], pid=pid) + assert claimed == ["x"] and reg.get_inflight(["x"])["x"].pid == pid + + # Live conflict: cannot steal from a live worker + reg.claim(["y"], pid=pid) + other = InflightRegistry(tmp_path) + assert other.claim(["y"], 
pid=dead_pid) == [] + other.close() + + reg.release(["x", "y"]) + reg.close() + + +def test_graceful_degradation(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: + """Corrupt, permission-denied, or deleted DB -> no crash, auto-recovery.""" + db_path = tmp_path / "inflight.db" + + # Seed the DB + reg = InflightRegistry(tmp_path) + reg.claim(["warmup"]) + reg.release(["warmup"]) + reg.close() + assert db_path.exists() + + for break_mode in ("corrupt", "permissions", "delete"): + if break_mode == "corrupt": + db_path.write_bytes(b"NOT A SQLITE DB") + elif break_mode == "permissions": + db_path.chmod(0o000) + elif break_mode == "delete": + if db_path.exists(): + db_path.unlink() + + reg2 = InflightRegistry(tmp_path) + with caplog.at_level(logging.WARNING): + reg2.claim(["a"]) + reg2.get_inflight(["a"]) + reg2.release(["a"]) + reg2.close() + + if break_mode == "permissions": + db_path.chmod(stat.S_IRWXU) + + # Auto-recovery: next access recreates a working DB + reg3 = InflightRegistry(tmp_path) + claimed = reg3.claim(["recovered"]) + assert claimed == ["recovered"] + assert "recovered" in reg3.get_inflight(["recovered"]) + reg3.release(["recovered"]) + reg3.close() + + +def test_inflight_session(tmp_path: Path) -> None: + # None registry -> yields all UIDs unchanged + with inflight_session(None, ["a", "b"]) as claimed: + assert claimed == ["a", "b"] + + # Normal flow: claims visible during session, released after + reg = InflightRegistry(tmp_path) + with inflight_session(reg, ["x", "y"]) as claimed: + assert set(claimed) == {"x", "y"} + check = InflightRegistry(tmp_path) + assert set(check.get_inflight(["x", "y"])) == {"x", "y"} + check.close() + check2 = InflightRegistry(tmp_path) + assert check2.get_inflight(["x", "y"]) == {} + check2.close() + + # Exception path: items still released in finally + reg2 = InflightRegistry(tmp_path) + with pytest.raises(ValueError, match="boom"): + with inflight_session(reg2, ["a"]) as claimed: + assert claimed == ["a"] + 
raise ValueError("boom") + check3 = InflightRegistry(tmp_path) + assert check3.get_inflight(["a"]) == {} + check3.close() + + # Nested / re-entrant: inner session must NOT release outer's claim + outer_reg = InflightRegistry(tmp_path) + with inflight_session(outer_reg, ["z"]) as outer_claimed: + assert outer_claimed == ["z"] + inner_reg = InflightRegistry(tmp_path) + with inflight_session(inner_reg, ["z"]) as inner_claimed: + assert inner_claimed == ["z"] + check4 = InflightRegistry(tmp_path) + assert "z" in check4.get_inflight(["z"]), "inner released outer's claim" + check4.close() + check5 = InflightRegistry(tmp_path) + assert check5.get_inflight(["z"]) == {} + check5.close() + + +def test_wait_for_inflight(tmp_path: Path) -> None: + dead_pid = 2**20 + 7 + + # Dead worker: wait detects dead PID and reclaims + reg = InflightRegistry(tmp_path) + reg.claim(["stale"], pid=dead_pid) + reg2 = InflightRegistry(tmp_path) + assert reg2.wait_for_inflight(["stale"]) == ["stale"] + assert reg2.get_inflight(["stale"]) == {} + reg2.close() + reg.close() + + # Non-Slurm job with fake job_id: must not hang, reclaimed as dead. + # Regression test for the case where a non-Slurm submitit job + # (DebugExecutor/LocalExecutor) accidentally gets job_id recorded. 
+ reg_ns = InflightRegistry(tmp_path) + reg_ns.claim(["non_slurm"], pid=dead_pid) + reg_ns.update_worker_info(["non_slurm"], job_id="99999", job_folder="/nonexistent") + info = reg_ns.get_inflight(["non_slurm"])["non_slurm"] + assert not info.is_alive(), "fake Slurm job should not appear alive" + reg_ns2 = InflightRegistry(tmp_path) + assert reg_ns2.wait_for_inflight(["non_slurm"]) == ["non_slurm"] + reg_ns2.close() + reg_ns.close() + + # Own PID: skipped to prevent self-deadlock + reg3 = InflightRegistry(tmp_path) + reg3.claim(["mine"]) + assert reg3.wait_for_inflight(["mine"]) == [] + assert "mine" in reg3.get_inflight(["mine"]) + reg3.release(["mine"]) + reg3.close() + + +def test_inflight_session_retries_lost_claim( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """When another worker grabs an item between wait and claim, + inflight_session must re-wait instead of silently skipping.""" + competitor_pid = 2**20 + 13 + wait_calls = 0 + alive_calls = 0 + original_wait = InflightRegistry.wait_for_inflight + original_is_alive = WorkerInfo.is_alive + + def wait_then_inject(self: InflightRegistry, item_uids: list[str]) -> list[str]: + nonlocal wait_calls + result = original_wait(self, item_uids) + wait_calls += 1 + if wait_calls == 1: + rival = InflightRegistry(tmp_path) + rival.claim(["x"], pid=competitor_pid) + rival.close() + return result + + def patched_is_alive(self: WorkerInfo) -> bool: + nonlocal alive_calls + if self.pid == competitor_pid: + alive_calls += 1 + return alive_calls == 1 # alive first check, dead on retry + return original_is_alive(self) + + monkeypatch.setattr(InflightRegistry, "wait_for_inflight", wait_then_inject) + monkeypatch.setattr(WorkerInfo, "is_alive", patched_is_alive) + + reg = InflightRegistry(tmp_path) + with inflight_session(reg, ["x"]) as claimed: + assert claimed == ["x"] + assert wait_calls >= 2, f"expected retry, got {wait_calls} wait calls" diff --git a/exca/map.py b/exca/map.py index 95f09ee..9d544ff 100644 
--- a/exca/map.py +++ b/exca/map.py @@ -11,19 +11,16 @@ import itertools import logging import os -import pickle import typing as tp -import uuid from concurrent import futures from pathlib import Path import numpy as np import pydantic import submitit -from submitit.core import utils from . import base, slurm -from .cachedict import CacheDict +from .cachedict import CacheDict, inflight from .utils import ShortItemUid @@ -74,45 +71,6 @@ def __call__(self, items: tp.Sequence[tp.Any]) -> tp.Iterator[tp.Any]: return self.infra._method_override(items) -class JobChecker: - """Keeps a record of running jobs in a folder - and enables waiting for them to complete. - """ - - def __init__(self, folder: Path | str) -> None: - basefolder = utils.JobPaths.get_first_id_independent_folder(folder) - self.folder = basefolder / "running-jobs" - - def add(self, jobs: tp.Iterable[tp.Any]) -> None: - """Add jobs to the list of running jobs""" - self.folder.mkdir(exist_ok=True, parents=True) - for job in jobs: - if not job.done(): - job_path = self.folder / (uuid.uuid4().hex[:8] + ".pkl") - with job_path.open("wb") as f: - pickle.dump(job, f) - - def wait(self) -> bool: - """Wait for completion of running jobs""" - waited = False - for fp in self.folder.glob("*.pkl"): - try: # avoid concurrency issues with deleted items - with fp.open("rb") as f: - job: tp.Any = pickle.load(f) - except Exception: # pylint: disable=broad-except - continue - if not job.done(): - msg = "Waiting for completion of pre-existing map job: %s\nin '%s'" - logger.info(msg, job, self.folder) - job.wait() - waited = True - # delete the file as it is not needed anymore - fp.unlink(missing_ok=True) - if waited: - logger.info("Waiting is over") - return waited - - def to_chunks( items: list[X], *, max_chunks: int | None, min_items_per_chunk: int = 1 ) -> tp.Iterator[list[X]]: @@ -247,6 +205,14 @@ def cache_dict(self) -> CacheDict[tp.Any]: state.cache_dict = cd return cd + def _inflight_registry(self) -> 
inflight.InflightRegistry | None: + """Create an InflightRegistry for the current cache folder, or None.""" + cache_folder = self.uid_folder() + if cache_folder is None: + return None + perm = self.permissions if isinstance(self.permissions, int) else None + return inflight.InflightRegistry(cache_folder, permissions=perm) + # pylint: disable=unused-argument def apply( self, @@ -332,9 +298,6 @@ def _find_missing(self, items: dict[str, tp.Any]) -> dict[str, tp.Any]: if not state.checked_configs: self._check_configs(write=True) if self.mode == "force": - # remove any item already computed, but not items being computed - # in another process (waited for by JobChecker) - # will not be removed to_remove = set(items) - set(missing) - state.recomputed if to_remove: msg = "Clearing %s items for %s (infra.mode=%s)" @@ -348,13 +311,6 @@ def _find_missing(self, items: dict[str, tp.Any]) -> dict[str, tp.Any]: if missing: if self.mode == "read-only": raise RuntimeError(f"{self.mode=} but found {len(missing)} missing items") - executor: submitit.Executor | None = self.executor() - if executor is not None: # wait for items being computed - jcheck = JobChecker(folder=executor.folder) - jcheck.wait() - # update cache dict and recheck as actual checking for keys updates the dict - keys = set(self.cache_dict) # update cache dict - missing = {k: item for k, item in missing.items() if k not in keys} if len(items) == len(missing) == 1 and self.forbid_single_item_computation: key, item = next(iter(missing.items())) raise RuntimeError( @@ -402,34 +358,68 @@ def _method_override(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Iterator[tp.An raise RuntimeError(f"Executor is None for {self.cluster!r}") # avoid processing same files at same time if several jobs overlap np.random.shuffle(missing) - # run on cluster - jobs = [] - chunks = list( - to_chunks( - [ki[1] for ki in missing], - max_chunks=self.max_jobs, - min_items_per_chunk=self.min_samples_per_job, - ) - ) - 
executor.update_parameters(slurm_array_parallelism=len(chunks)) - with self._work_env(), executor.batch(): # submitit>=1.4.6 - for chunk in chunks: - # select a batch/chunk of samples_per_job items to send to a job - j = executor.submit(self._call_and_store, chunk, use_cache_dict=True) - jobs.append(j) - jcheck = JobChecker(folder=executor.folder) - jcheck.add(jobs) - # pylint: disable=expression-not-assigned - uid = self.uid() - msg = "Sent %s samples for %s into %s jobs on cluster '%s' (eg: %s)" - logger.info( - msg, len(missing), uid, len(jobs), executor.cluster, jobs[0].job_id - ) - [j.result() for j in jobs] # wait for processing to complete - logger.info("Finished processing %s samples for %s", len(missing), uid) - folder = self.uid_folder() - if folder is not None: - os.utime(folder) # make sure the modified time is updated + registry = self._inflight_registry() + with inflight.inflight_session( + registry, [k for k, _ in missing] + ) as claimed_uids: + claimed_set = set(claimed_uids) + # Re-check cache after wait: other workers may have completed + # items while we were blocked in inflight_session. 
+ if self.folder is not None: + keys = set(self.cache_dict) + missing = [ + (k, item) + for k, item in missing + if k in claimed_set and k not in keys + ] + else: + missing = [(k, item) for k, item in missing if k in claimed_set] + if missing: + jobs: list[tp.Any] = [] + uid_item_chunks = list( + to_chunks( + missing, + max_chunks=self.max_jobs, + min_items_per_chunk=self.min_samples_per_job, + ) + ) + executor.update_parameters( + slurm_array_parallelism=len(uid_item_chunks) + ) + with self._work_env(), executor.batch(): # submitit>=1.4.6 + for chunk in uid_item_chunks: + j = executor.submit( + self._call_and_store, + [item for _, item in chunk], + use_cache_dict=True, + ) + jobs.append(j) + if registry is not None: + for chunk, j in zip(uid_item_chunks, jobs): + if isinstance(j, submitit.SlurmJob): + registry.update_worker_info( + [uid for uid, _ in chunk], + job_id=str(j.job_id), + job_folder=str(j.paths.folder), + ) + # pylint: disable=expression-not-assigned + uid = self.uid() + msg = "Sent %s samples for %s into %s jobs on cluster '%s' (eg: %s)" + logger.info( + msg, + len(missing), + uid, + len(jobs), + executor.cluster, + jobs[0].job_id, + ) + [j.result() for j in jobs] # wait for processing to complete + logger.info( + "Finished processing %s samples for %s", len(missing), uid + ) + folder = self.uid_folder() + if folder is not None: + os.utime(folder) # make sure the modified time is updated cache_dict = self.cache_dict msg = "Recovering %s items for %s from %s" logger.debug(msg, len(items), self._factory(), cache_dict) @@ -455,47 +445,62 @@ def _method_override_futures(self, items: tp.Sequence[tp.Any]) -> tp.Iterator[tp pool = None # avoid processing same files at same time if several jobs overlap np.random.shuffle(missing) - if pool is None: - # run locally - msg = "Computing %s missing items" - logger.debug(msg, len(missing)) - cached = self.folder is not None - out = self._call_and_store( - [ki[1] for ki in missing], use_cache_dict=cached - ) - elif 
pool not in ("processpool", "threadpool"): - raise RuntimeError(f"Unexpected pool {pool!r}") - else: - ExecutorCls = ( - futures.ThreadPoolExecutor - if pool == "threadpool" - else futures.ProcessPoolExecutor - ) - jobs = [] - cpus = max(1, (os.cpu_count() or 1) - 1) - max_workers = min(len(missing), cpus) - if self.max_jobs is not None: - max_workers = min(max_workers, self.max_jobs) - with ExecutorCls(max_workers=max_workers) as ex: - mitems = [ki[1] for ki in missing] - chunks = to_chunks(mitems, max_chunks=3 * max_workers) - for chunk in chunks: - j = ex.submit( - self._call_and_store, - chunk, - use_cache_dict=self.folder is not None, - ) - jobs.append(j) - uid = self.uid() - msg = "Sent %s items for %s into a %s" - logger.info(msg, len(missing), uid, pool) - iterator = _set_tqdm(futures.as_completed(jobs), total=len(jobs)) - for job in iterator: - out.update(job.result()) # raise asap - logger.info("Finished processing %s items for %s", len(missing), uid) - folder = self.uid_folder() - if folder is not None: - os.utime(folder) # make sure the modified time is updated + with inflight.inflight_session( + self._inflight_registry(), [k for k, _ in missing] + ) as claimed_uids: + claimed_set = set(claimed_uids) + # Re-check cache after wait: other workers may have completed + # items while we were blocked in inflight_session. 
+ if self.folder is not None: + keys = set(self.cache_dict) + missing = [ + (k, item) + for k, item in missing + if k in claimed_set and k not in keys + ] + else: + missing = [(k, item) for k, item in missing if k in claimed_set] + if pool is None and missing: + msg = "Computing %s missing items" + logger.debug(msg, len(missing)) + cached = self.folder is not None + out = self._call_and_store( + [ki[1] for ki in missing], use_cache_dict=cached + ) + elif missing: + if pool not in ("processpool", "threadpool"): + raise RuntimeError(f"Unexpected pool {pool!r}") + ExecutorCls = ( + futures.ThreadPoolExecutor + if pool == "threadpool" + else futures.ProcessPoolExecutor + ) + jobs: list[futures.Future[tp.Any]] = [] + cpus = max(1, (os.cpu_count() or 1) - 1) + max_workers = min(len(missing), cpus) + if self.max_jobs is not None: + max_workers = min(max_workers, self.max_jobs) + with ExecutorCls(max_workers=max_workers) as ex: + mitems = [ki[1] for ki in missing] + chunks = to_chunks(mitems, max_chunks=3 * max_workers) + for chunk in chunks: + j = ex.submit( + self._call_and_store, + chunk, + use_cache_dict=self.folder is not None, + ) + jobs.append(j) + uid = self.uid() + msg = "Sent %s items for %s into a %s" + logger.info(msg, len(missing), uid, pool) + iterator = _set_tqdm(futures.as_completed(jobs), total=len(jobs)) + for job in iterator: + out.update(job.result()) # raise asap + logger.info("Finished processing %s items for %s", len(missing), uid) + if missing: + folder = self.uid_folder() + if folder is not None: + os.utime(folder) try: cache_dict = self.cache_dict except ValueError: # no caching diff --git a/exca/steps/backends.py b/exca/steps/backends.py index 2ac88e5..8877946 100644 --- a/exca/steps/backends.py +++ b/exca/steps/backends.py @@ -26,6 +26,7 @@ import exca from exca import utils +from exca.cachedict import inflight if tp.TYPE_CHECKING: from .base import Step @@ -429,10 +430,24 @@ def run(self, func: tp.Callable[..., tp.Any], *args: tp.Any) -> 
tp.Any: logger.debug("Recovering job: %s", self.paths.job_pkl) if job is None: - wrapper = _CachingCall(func, self.paths, self.cache_type) - job = self._submit(wrapper, *args) + item_uid = self.paths.item_uid + registry: inflight.InflightRegistry | None = None + if type(self) is not Cached: + registry = inflight.InflightRegistry(self.paths.cache_folder) + with inflight.inflight_session(registry, [item_uid]) as claimed: + if claimed and self._cache_status() is None: + wrapper = _CachingCall(func, self.paths, self.cache_type) + job = self._submit(wrapper, *args) + if isinstance(job, submitit.SlurmJob) and registry is not None: + registry.update_worker_info( + [item_uid], + job_id=str(job.job_id), + job_folder=str(job.paths.folder), + ) + job.result() + return self._load_cache() - job.result() # Wait (result is cached, not returned) + job.result() return self._load_cache() def _submit(self, wrapper: _CachingCall, *args: tp.Any) -> tp.Any: diff --git a/exca/test_map.py b/exca/test_map.py index 3563098..f055239 100644 --- a/exca/test_map.py +++ b/exca/test_map.py @@ -117,8 +117,12 @@ def test_map_infra_cache_dict_calls(tmp_path: Path) -> None: assert max(r.readings for r in cd._jsonl_readers.values()) == 1 _ = list(whatever.process([2, 3, 4])) assert max(r.readings for r in cd._jsonl_readers.values()) == 1 - _ = list(whatever.process([5])) - assert max(r.readings for r in cd._jsonl_readers.values()) == 2 + for _ in range(2): + _ = list(whatever.process([5])) + # +2 reads: _find_missing + inflight cache re-check + assert max(r.readings for r in cd._jsonl_readers.values()) == 3 + _ = list(whatever.process([6])) + assert max(r.readings for r in cd._jsonl_readers.values()) == 5 def test_missing_yield() -> None: