RFC 0004 — Multi-GPU Tile Pipeline¶
| Status | Implemented in v0.3 |
| Author | OpenLithoHub maintainers |
| Created | 2026-05-20 |
| Targets | v0.3 |
| Related | openlithohub.workflow.tiling, openlithohub.workflow.parallel, openlithohub.cli.optimize_cmd, RFC 0001 |
Summary¶
Today openlithohub optimize run processes a tiled layout sequentially —
one tile at a time, on a single device. For full-chip designs (M1 GDS in
the GB range) this is the dominant wall-clock cost. This RFC scopes how
to scale that loop across multiple GPUs (and eventually nodes) without
touching the model layer or breaking the ONNX/TorchScript export path.
This is a research RFC. No code lands from this document; it picks a direction so a follow-up implementation RFC can be narrow.
Current state (factual)¶
Verified against main at 2026-05-20:
- Tile loop is a sequential
for tile in tiles: model.predict(tile)incli/optimize_cmd.py:142–145. No batching, no async. - Tile geometry already handles halo overlap and ramp-blended stitching
(
workflow/tiling.py:11–181). Tiles are independent at inference time — the only cross-tile coupling is the post-hoc stitch. - Model contract (
models/base.py:46–56, 64–80): predict(design, **kwargs) -> PredictionResultaccepts(H, W)or(B, C, H, W)and unsqueezes internally.to_torch_module() -> nn.Modulereturns a barenn.Modulefor export — not wrapped, not on any accelerator.- Device placement is explicit, per-model:
.to(self._device)/device=kwarg (neural_ilt.py:47,97,levelset_ilt.py:105–107). No.cuda(), no global device manager. torch.compileis applied conditionally to the Hopkins forward kernel only (levelset_ilt.py:127), with a cache key includingstr(target.device)so per-device compilation works.- Existing distributed infra: none. No
torch.distributed,torch.multiprocessing,DataParallel,DistributedDataParallel,accelerate, orrayinsrc/. - GPU-gated tests: none.
tests/test_workflow/has notorch.cuda.is_availableskips and no distributed harness.
Hard constraints¶
- Export must keep working.
cli/export_cmd.py:105callslitho_model.to_torch_module()and expects a barenn.Module. Any solution that wraps the model itself (e.g.accelerator.prepare(model),DistributedDataParallel(model)) breaks ONNX export. Therefore parallelism must wrap the tile loop, not the model. - Single-GPU and CPU paths are the default (Colab, laptops, CI). Multi-GPU is opt-in. No new required dependency.
torch.compilecache should not be invalidated. Cache is keyed by device string; one compile per worker is acceptable, recompiling per tile is not.- The model layer stays untouched. No
models/*.pychanges for parallelism.
Options considered¶
Option A — torch.multiprocessing.spawn over tile shards¶
Spawn N worker processes, shard tiles round-robin, each worker pins
itself to one GPU and runs the existing per-tile predict loop, results
go back via a Queue and the main process stitches.
- + Zero new runtime dependencies (stdlib + torch).
- + Model layer untouched; export path untouched.
- +
torch.compilecache lives per-worker — natural fit. - + Halo/stitch logic stays in the main process; workers see independent tiles.
- − We hand-roll worker lifecycle, error propagation, cancellation, KeyboardInterrupt forwarding. Not free, but ~150 lines bounded.
- − No multi-node story (need MPI /
torchrunfor that, future work). - − Model weights are reloaded per worker — RAM cost ≈ N × weights. Acceptable for current model sizes (NeuralILT < 100 MB), revisit if a backbone lands from RFC 0001.
Option B — Hugging Face Accelerate¶
Use Accelerator + accelerator.prepare() to manage devices, optionally
launch via accelerate launch.
- + Battle-tested, mature, multi-node out of the box.
- + Mixed-precision and gradient-accumulation primitives we may want later for training-time pipelines.
- − Idiomatic Accelerate wraps the model — exactly what our
hard constraint #1 forbids. We could use Accelerate purely for its
process-group bootstrap and skip
prepare(model), but at that point we are using ~10% of the library. - − Adds a non-trivial dependency (
acceleratepullshuggingface_hub,psutil, etc.) for what would, in our usage, be a thin process launcher. - − Users now need to know whether to invoke
python -m openlithohub …oraccelerate launch -m openlithohub …. Two entry points is a documentation tax. - − Export path testing burden: every release we have to verify
acceleratedid not silently start wrapping our model.
Option C — Ray Core¶
Ray actors, one per GPU, tiles.map(actor.predict).
- + Cleanest multi-node story; cluster-aware scheduling for free.
- + Decouples tile scheduling from model code completely.
- − Heaviest dependency by a wide margin (Ray + grpc + plasma store). An order of magnitude more install footprint than A or B.
- − Operational complexity (Ray head node, dashboard, etc.) is overkill for a single-node 8-GPU box, which is the realistic near-term target.
- − No current users have asked for cluster scheduling. Premature.
Option D — Status quo + advisory docs¶
Document that single-GPU is the supported configuration; users who want parallelism shard the input GDS themselves and run N CLIs.
- + Zero new code.
- − The whole point of the "industrial software" pitch is that we handle this for them.
Recommendation¶
Option A (torch.multiprocessing.spawn) for v0.3, with Option C
(Ray) kept on the table for v0.4+ if multi-node demand materialises.
Reject Option B (Accelerate) for this use case — its core value-add
is model wrapping, which our export constraint forbids; using it as a
process launcher is paying a dependency for very little.
Why not Accelerate¶
The Accelerate library shines when training is the workload: model sharding, mixed-precision, gradient accumulation, FSDP. Our workload is inference over independent tiles — embarrassingly parallel, with a hard "do not wrap the model" constraint from the export path. Using Accelerate here means importing the library and then opting out of its main API. That is a smell.
Why not Ray yet¶
Ray is the right answer when the unit of parallelism is a cluster. Today the unit is "an 8-GPU workstation" or "an 8-GPU node in a fab LSF queue." A 2 MB stdlib dependency is a much better fit than a 200 MB Ray cluster dependency for that target. Re-evaluate when a user actually asks for multi-node — not before.
Minimal landing path (for the implementation RFC)¶
The follow-up RFC should bound the work to:
- CLI surface: add
--num-gpus Ntocli/optimize_cmd.py.N=1(default) keeps the current sequential path bit-for-bit identical.N>1enters the multiprocessing path. - New module:
src/openlithohub/workflow/parallel.py(~150 lines) parallel_tile_inference(model_factory, tiles, num_gpus) -> list[(Tile, Tensor)]- Workers receive a factory (model name + kwargs), not a live model object — keeps pickling sane and avoids CUDA-context-fork hazards.
mp.spawnwithstart_method='spawn'(notfork— CUDA + fork is undefined behaviour).- No new runtime dependency. Stdlib + torch only.
- Tests (
tests/test_workflow/test_parallel.py): - CPU-only smoke test using
num_gpus=2mapped tocpu(validate the dispatch logic without needing GPUs in CI). - GPU test gated on
torch.cuda.device_count() >= 2, marked@pytest.mark.gpu, skipped in default CI. - Regression coverage for the unchanged path:
- Existing tiling tests must pass with
num_gpus=1(the default). - Export path test (
test_cli/test_export.py, if present) must pass unchanged — the model layer is not touched. - Docs: one new section in
docs/cli-reference.mdunderoptimize run, plus a one-paragraph mention indocs/architecture.md. No new top-level page.
Out of scope for v0.3¶
- Multi-node (Ray /
torchrun/ Slurm). - Mixed-precision inference.
- Streaming OASIS reader (the assumption is the layout fits in host RAM as a tensor; this is the same assumption today).
- Per-GPU dynamic load balancing (we shard tiles round-robin; if some tiles are slower than others, c'est la vie for v0.3).
Open questions for the implementation RFC¶
- Should workers each load their own model weights, or should the main process load once and pass via shared memory? (Probably "each loads" — shared memory + CUDA is fragile.)
- How to surface per-worker
torch.compilewarmup time in the progress bar? (Probably ignore for v0.3 — first tile per worker is slow, subsequent ones aren't.) - KeyboardInterrupt propagation —
mp.spawnswallows signals in subtle ways. Worth a dedicated test.
Implementation (landed v0.3)¶
- New module:
src/openlithohub/workflow/parallel.py—parallel_tile_inference(model_name, model_kwargs, tiles, *, num_gpus, base_perf_kwargs, progress_cb). Workers receive a registry name + kwargs (factory, not a live model) and re-instantiate viaregister_builtin_models()+registry.get(...). Spawn context (mp.get_context("spawn")), manualctx.Processrather thanmp.spawnfor cleaner error / KeyboardInterrupt propagation. - CLI:
--num-gpus Nadded tooptimize run(cli/optimize_cmd.py).N=1(default) keeps the sequential path verbatim. - Models layer untouched. Export path
(
cli/export_cmd.py:litho_model.to_torch_module()) is unaffected. - Tests:
tests/test_workflow/test_parallel.pycovers CPU dispatch, sequential-vs-parallel parity, error propagation, round-robin balancing, default-path regression, and a GPU-gated (@pytest.mark.gpu) two-device test. - Resolved open questions:
- Q1: workers each load (no shared memory).
- Q2: progress advances on result drain — workers each pay one
torch.compilewarmup, surfaced as a slow first tile per worker. - Q3: covered by
test_parallel_worker_error_propagatesand theKeyboardInterruptbranch inparallel_tile_inference.