feat(phase_thick_3d_tilt): Refactor of ops_process PR#99 into OPS-agnostic waveorder optimizations #562
Open
mark-a-potts wants to merge 8 commits into
`_compute_shared_optics` materialized the propagation kernel and Green's
function via CPU `torch.exp`, then `.to(device)`d the result -- wasting
~1.28 s/call at OPS dims while a GPU sat idle. Adding a `device=` kwarg
that threads down to `util.generate_frequencies` and `torch.arange`
moves the build directly to the caller's device. `calculate_transfer_function`
now passes `device=zen.device`, picking up the speedup automatically; the
trailing `.to(device)` calls are retained as no-op guards for any external
`_compute_shared_optics` override that doesn't honor `device=`.
Bench (RTX 6000 Pro Blackwell, OPS dims `(40, 512, 512)` + z_padding=10):
baseline (cpu build + .to): p50 1284.78 ms
new (device build): p50 1.15 ms 1113x
Numerical equivalence (CPU vs CUDA build, float32):
fyy, fxx: bit-identical (max_abs_diff 0.0)
det_pupil: 5.1e-6
propagation_kernel: 6.1e-5 <-- bounded by CUDA torch.exp precision
greens_function_z: 1.8e-6
Pearson: >= 0.999999 on all five tensors
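The two agreement metrics quoted above can be computed as in the following minimal sketch. It works on plain Python lists rather than tensors, the helper names `max_abs_diff` and `pearson` are illustrative and not part of waveorder, and the sample values are made up to mimic a perturbation of about 6e-5.

```python
import math


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two sequences."""
    return max(abs(x - y) for x, y in zip(a, b))


def pearson(a, b):
    """Pearson correlation coefficient of two equal-length real sequences."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


# Toy stand-ins for a flattened CPU-built and CUDA-built tensor (assumed values).
cpu_build = [0.10, 0.25, 0.40, 0.55, 0.70]
gpu_build = [0.10, 0.25 + 6e-5, 0.40, 0.55 - 6e-5, 0.70]

diff = max_abs_diff(cpu_build, gpu_build)
corr = pearson(cpu_build, gpu_build)
```

With these toy inputs both thresholds used by the CUDA-vs-CPU test (`max_abs_diff < 1e-3`, `Pearson >= 0.999999`) are met.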
This is the upstream-able core of OPS Strand C (the `_install_gpu_shared_optics_patch`
monkey-patch in `ops_process/reconstruct_tilt_corrected.py`). Full-scale ops0042
7035-position run dropped 89 min -> 49.7 min with that patch in place, GPU SM
util 13% -> 78%; landing this upstream lets every waveorder consumer pick up
the same win without monkey-patching.
Tests:
- `test_compute_shared_optics_default_is_cpu` - back-compat: `device=None` still
materializes on CPU.
- `test_compute_shared_optics_device_str_cpu` - string `"cpu"` accepted.
- `test_compute_shared_optics_cuda_matches_cpu` - CUDA build agrees with CPU within
float32 transcendental precision (max_abs_diff < 1e-3, Pearson >= 0.999999).
- `test_calculate_transfer_function_device_threading` - CUDA-resident tilt angles
produce CUDA-resident TFs end-to-end.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…reconstruction

Two extensions to `_optimize_gradient` that the OPS tilt-recon work needs
upstream:

1. **Per-tile initial values.** In batched mode (`data.ndim == 4`),
   `optimizable_params[name][0]` may now be a `(B,)` tensor of per-tile
   starting points (e.g. from a calibration warm-start map), not just a
   scalar broadcast across the batch. A 0-d tensor still works and
   broadcasts the legacy way.
2. **Frozen parameters via `lr == 0`.** A parameter with learning rate 0
   is held at its initial value across iterations: it's still passed to
   `reconstruct_fn` (so the forward model sees the per-tile prior) but
   excluded from the Adam param-groups and given `requires_grad=False`.
   This is the `z-only` tilt refinement recipe: `tilt_angle_zenith` and
   `tilt_angle_azimuth` pinned to map-derived priors, only
   `z_focus_offset` moves. At least one parameter must be free; otherwise
   the call raises.

Both features are fully backwards-compatible: scalar `init_val` + `lr > 0`
behaves identically to the previous implementation.

Tests
-----
- `test_batched_optimization_independent_tiles`: B tiles converge to B
  independent targets in one batched call (existing batched behavior, now
  explicitly covered).
- `test_per_tile_initial_value_tensor`: `(B,)` tensor init lands each tile
  near its individual target.
- `test_frozen_axis_does_not_move`: frozen scalar param stays at init.
- `test_per_tile_init_with_frozen_param`: per-tile init + freeze
  combination: frozen tensor retains per-tile values; free param picks up
  the slack.
- `test_all_frozen_raises`: degenerate "every param frozen" config is
  rejected with a clear ValueError.
- `test_per_tile_init_shape_mismatch_raises`: wrong-shape per-tile init in
  batched mode is rejected.

All 16 tests in `tests/optim/test_optimize.py` pass; full `tests/optim/` and
`tests/models/` suites still pass (95 passed, 2 CUDA-skipped).

Source: OPS-side `_gpu_optimize_tilt_params` in
`ops_process/ops_analysis/processes/reconstruct_tilt_corrected.py:1110`,
which currently handles both features via `OPS_TILT_FREEZE_*` env vars and
an in-process warm-start dict. After this lands, the ops_process adapter
shrinks to env-var → kwargs translation.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
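The free/frozen convention described in this commit can be illustrated with a small pure-Python sketch. It assumes each entry of `optimizable_params` is an `(init, lr)` pair and uses lists in place of `(B,)` tensors; the function name `split_free_and_frozen` is hypothetical and does not claim to match waveorder's internals.

```python
def split_free_and_frozen(optimizable_params, batch_size=None):
    """Partition parameters by learning rate (lr == 0 means frozen).

    optimizable_params maps name -> (init, lr). init is a float or, in
    batched mode, a list with one value per tile. Returns (free, frozen),
    each mapping name -> per-tile list (or a float when unbatched).
    """
    free, frozen = {}, {}
    for name, (init, lr) in optimizable_params.items():
        if isinstance(init, (list, tuple)):
            # Per-tile init must match the batch size exactly.
            if batch_size is None or len(init) != batch_size:
                raise ValueError(
                    f"{name}: per-tile init has length {len(init)}, "
                    f"expected {batch_size}"
                )
            values = [float(v) for v in init]
        elif batch_size is not None:
            values = [float(init)] * batch_size  # legacy scalar broadcast
        else:
            values = float(init)
        (frozen if lr == 0 else free)[name] = values
    if not free:
        raise ValueError("at least one parameter must have lr > 0")
    return free, frozen


params = {
    "tilt_angle_zenith": ([0.01, 0.02, 0.03], 0.0),  # frozen, per-tile prior
    "tilt_angle_azimuth": (0.5, 0.0),                # frozen, broadcast
    "z_focus_offset": (0.0, 0.05),                   # the only free parameter
}
free, frozen = split_free_and_frozen(params, batch_size=3)
```

Only the `free` dict would be handed to the optimizer; the `frozen` values are still forwarded to the forward model unchanged.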
…nly halves

`_compute_shared_optics` always recomputed every tensor -- both the
angle-fixed ones (fyy, fxx, radial_frequencies, det_pupil) and the
z-dependent ones (propagation_kernel, greens_function_z) -- even when
callers only varied z across iterations.

Factor the function into three helpers, all back-compat:

- `_compute_angle_optics(yx_shape, yx_pixel_size, wavelength_illumination,
  numerical_aperture_detection, pupil_steepness, device)`
  Returns the four tensors that don't depend on z. Build once per position
  (or fewer times if shape/NA/wavelength are also constant).
- `_compute_z_position_list(z_shape, z_pixel_size, z_padding,
  invert_phase_contrast, device)`
  Pulled the z-list construction out so callers can rebuild only this when
  only z varies.
- `_compute_z_optics(radial_frequencies, det_pupil, z_position_list,
  wavelength_illumination, index_of_refraction_media)`
  Returns the propagation kernel + Green's function. Re-call per optimizer
  iteration in z-only tilt-recon.

`_compute_shared_optics` is preserved as a thin wrapper that composes all
three; its output is unchanged (verified by
`test_angle_z_split_composes_to_shared_optics`, which compares the new
composed call against the legacy one for bitwise equality).

Motivating use case: the OPS `FREEZE_ANGLES=1` tilt-recon recipe
(per-position warm-start + 3-8 optimizer iterations). Today each Adam step
rebuilds the entire optics from scratch. With the split, callers cache the
angle half once per position and re-call only the z half per iter -- ~50%
of per-iter optics build cost reclaimed for the cost of a few cached
tensors.

Tests
-----
- `test_angle_z_split_composes_to_shared_optics` -- new helpers compose to
  bit-identical legacy output.
- `test_angle_optics_cached_across_z_changes` -- angle outputs are
  invariant to z config; the cache is correct to hold.

All 10 phase_thick_3d tests pass (CPU); 2 CUDA-gated tests skipped on the
login node, validated previously.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
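The caching pattern behind this split can be sketched with a 1-D toy model in pure Python. The function names, the hard-edged pupil, and the simplified propagation phase are all assumptions made for illustration; they are not waveorder's helpers or its exact optics.

```python
import cmath
import math


def compute_angle_half(n, pixel_size, wavelength, na_detection):
    """z-independent part: radial frequencies and a hard-edged pupil (1-D toy)."""
    radial = [abs((i - n // 2) / (n * pixel_size)) for i in range(n)]
    cutoff = na_detection / wavelength
    pupil = [1.0 if r <= cutoff else 0.0 for r in radial]
    return radial, pupil


def compute_z_half(radial, pupil, z_positions, wavelength, n_media):
    """z-dependent part: pupil-masked propagation kernel, one row per z."""
    k = n_media / wavelength
    kernel = []
    for z in z_positions:
        row = []
        for r, p in zip(radial, pupil):
            axial = math.sqrt(max(k * k - r * r, 0.0))
            row.append(p * cmath.exp(2j * math.pi * z * axial))
        kernel.append(row)
    return kernel


def compute_all(n, pixel_size, wavelength, na_detection, z_positions, n_media):
    """Thin wrapper composing both halves, mirroring a legacy single call."""
    radial, pupil = compute_angle_half(n, pixel_size, wavelength, na_detection)
    return compute_z_half(radial, pupil, z_positions, wavelength, n_media)


# Cache the angle half once, then rebuild only the z half per iteration.
radial, pupil = compute_angle_half(8, 0.1, 0.5, 1.2)
for z_offset in (0.0, 0.25, 0.5):
    z_list = [z_offset + dz for dz in (-1.0, 0.0, 1.0)]
    split = compute_z_half(radial, pupil, z_list, 0.5, 1.33)
    legacy = compute_all(8, 0.1, 0.5, 1.2, z_list, 1.33)
    assert split == legacy  # same arithmetic in the same order
```

Because the wrapper runs exactly the same arithmetic as the split path, the two results compare equal element by element, which is the property the bitwise-equality test relies on.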
…ndent halves

Companion change to the phase_thick_3d split (commit e66db59). The OPS
tilt-recon optimizer's hot loop calls `isotropic_thin_3d.reconstruct(...)`
per iter, not `phase_thick_3d` -- so the FREEZE_ANGLES caching benefit
needs the same factoring here.

Split `_calculate_wrap_unsafe_transfer_function` into three helpers (all
back-compat: the wrapper still produces bit-identical output):

- `_compute_angle_optics(yx_shape, yx_pixel_size, wavelength,
  index_of_refraction_media, NA_ill, NA_det, tilt_zenith, tilt_azimuth,
  pupil_steepness, device)`
  Returns a dict of the angle-fixed tensors: fyy, fxx, radial_frequencies,
  detection_pupil, illumination_pupil (the tilted pupil, which depends on
  zenith/azimuth -- "angle-fixed" means it's fixed across optimizer iters
  when angles are frozen).
- `_compute_z_propagation(angle_optics, z_position_list,
  invert_phase_contrast)`
  Builds the propagation kernel for the current z list and returns
  `det_prop = detection_pupil * propagation_kernel` -- the only
  z-dependent piece.
- `_wotf_from_split_optics(angle_optics, det_prop)`
  Final assembly: WOTF from the cached illumination pupil + the per-iter
  det_prop. Handles batched vs unbatched output shapes.

`_calculate_wrap_unsafe_transfer_function` is now a thin back-compat
wrapper that composes the three. Public APIs
(`calculate_transfer_function`, `reconstruct`) unchanged.

Why this lives in waveorder and not ops_process
------------------------------------------------
The angle/z factoring is a property of the optics math, not the OPS
recipe. Any consumer that holds zenith / azimuth / NA fixed across
optimizer iterations on z benefits -- not just OPS. Specifically:

- OPS tilt-recon (FREEZE_ANGLES=1 recipe): builds angle optics once per
  position, re-calls `_compute_z_propagation` per Adam/Newton iter with
  the current z. Saves ~50% of the per-iter optics build cost, which is a
  non-trivial fraction of total per-iter wall.
- Future autofocus / focus-sweep workloads: same shape.

Tests (CPU)
-----------
- `test_thin_3d_angle_z_split_composes_to_wrap_unsafe` -- bit-identical
  legacy output.
- `test_thin_3d_angle_optics_cached_across_z_changes` -- cached angle
  optics + per-iter z propagation matches the legacy single-call path
  across three different z lists (the FREEZE_ANGLES workflow).
- `test_thin_3d_angle_optics_batched_tilt` -- batched (B,) tilt angles
  produce the same split output as legacy.

All 6 thin_3d tests + 10 phase_thick_3d tests pass on CPU; 2 CUDA-gated
phase_thick_3d tests skipped on the login node.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…t-recon

Public consumer-facing API on top of the angle/z optics split (commit
f49ab6b). Builds the angle-fixed optics once at construction; each call to
`transfer_functions(z_position_list)` rebuilds only the z-dependent
propagation kernel and composes the WOTF.

Drop-in replacement for the legacy single-shot
`calculate_transfer_function` from inside the OPS optimizer hot loop:

    cache = CachedTiltOptics(
        yx_shape=...,
        yx_pixel_size=...,
        wavelength_illumination=...,
        index_of_refraction_media=...,
        numerical_aperture_illumination=...,
        numerical_aperture_detection=...,
        tilt_angle_zenith=...,
        tilt_angle_azimuth=...,  # FROZEN
        device="cuda",
    )
    for z_iter in optimizer.iters:
        z_positions = (z_idx + z_p.mean()) * z_pixel_size
        Hu, Hp = cache.transfer_functions(z_positions)
        # apply_inverse_transfer_function(...) using Hu, Hp

Output bit-identical to fresh single-shot
`_calculate_wrap_unsafe_transfer_function` (validated by two new tests).
The cache is single-position; callers create one per position.

Per-iter savings depend on the relative cost of building the angle half
vs. the z half + the inverse-TF FFT. For OPS subtile sizes (typically
~256x256) the angle half is a non-trivial fraction of the per-iter optics
build, so this pays back over 3-8 optimizer iterations.

Tests
-----
- `test_cached_tilt_optics_matches_legacy` -- single-shot equivalence.
- `test_cached_tilt_optics_reusable_across_z_iterations` -- the actual
  FREEZE_ANGLES workflow: re-call with different z lists, bit-identical to
  legacy fresh calls.

All 8 thin_3d tests pass.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…agonal Newton)
Adds a Newton-method backend to `optimize_reconstruction`. For each
free parameter, computes the first and second derivatives of the
scalar loss via `torch.autograd.grad` and takes the LM-damped step:
    step = -grad / max(hessian, damping)
with a max-step cap. The Hessian is the per-parameter diagonal
(second derivative w.r.t. that parameter alone); for batched ``(B,)``
parameters and a loss that factorizes per tile this is the exact
per-tile second derivative, since off-diagonal entries are zero by
independence.
`optimizable_params` semantics for ``"newton"``:
- ``init`` -- initial value (scalar or per-tile tensor; same shape
rules as Adam).
- ``lr`` -- LM damping floor AND max-step cap.
Frozen params (``lr == 0``) follow the same convention as the
gradient path: passed to ``reconstruct_fn`` but not updated.
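A minimal sketch of the damped step on a scalar quadratic follows. It uses analytic derivatives in place of `torch.autograd.grad`, and it reads "``lr`` is both the damping floor and the max-step cap" literally; how the real backend combines the two is an assumption here, and `newton_step` / `minimize_quadratic` are illustrative names.

```python
def newton_step(grad, hessian, lr):
    """One LM-damped diagonal Newton step for a single scalar parameter.

    lr is used both as the curvature floor (damping) and as the cap on
    the step magnitude, following the description above.
    """
    step = -grad / max(hessian, lr)
    return max(-lr, min(lr, step))


def minimize_quadratic(curvature, target, x0, lr, iters):
    """Minimize curvature * (x - target)**2 with analytic derivatives."""
    x = x0
    for _ in range(iters):
        grad = 2.0 * curvature * (x - target)
        hessian = 2.0 * curvature
        x += newton_step(grad, hessian, lr)
    return x


# Cap larger than the distance to the minimum: one step lands on the target.
x_one = minimize_quadratic(curvature=3.0, target=0.4, x0=0.0, lr=1.0, iters=1)
# Cap smaller than the distance: the first step is clipped to lr.
x_capped = minimize_quadratic(curvature=3.0, target=0.4, x0=0.0, lr=0.1, iters=1)
# Several capped steps still reach the target.
x_many = minimize_quadratic(curvature=3.0, target=0.4, x0=0.0, lr=0.1, iters=10)
```

On an exact quadratic the undamped step is the full distance to the minimum, which is why a single iteration suffices when the cap does not bind.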
Why Newton, for the FREEZE_ANGLES tilt-recon use case
-----------------------------------------------------
The OPS tilt-recon loop freezes zenith/azimuth and refines only z
around a warmstart-map init. The loss surface near a good init is
dominated by the local quadratic; Newton lands at the minimum in
2-3 iterations vs Adam's 5-8. Each Newton iter costs one extra
`autograd.grad` call (the Hessian) on top of the standard forward +
backward. Net: per-position iter count drops ~2x.
Already prototyped in `ops_process.reconstruct_tilt_corrected` gated
by `OPS_TILT_OPTIMIZER=newton`. This commit moves it upstream so any
waveorder consumer can opt in via `method="newton"`.
Tests
-----
- `test_newton_converges_on_quadratic` -- 1-iter convergence on exact
quadratic.
- `test_newton_batched_independent_quadratics` -- B independent
quadratic problems, each tile lands at its own target in 5 iters.
- `test_newton_frozen_axis_does_not_move` -- lr=0 param stays put.
- `test_newton_per_tile_init_tensor` -- per-tile tensor init works,
same shape rules as Adam path.
- `test_newton_all_frozen_raises` -- degenerate "all frozen" config
rejected, consistent with Adam path.
Full test sweep: 107 passed across optim/ and models/, 2 CUDA-gated
skipped on the login node.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
For the (s=2, Z) transfer-function matrix M in isotropic_thin_3d, the
SVD-based inverse filter U Σ_reg Vh equals (M Mᴴ + λI)⁻¹ @ M via the
thin-SVD identity (M Mᴴ = U Σ² Uᴴ for orthonormal Vh rows). Since
(M Mᴴ + λI) is just 2×2 Hermitian PD, the inverse has a closed form
(1/det · [[d,-c],[-c.conj(),a]]), with no SVD and no eigendecomposition.

Microbench on H200, N=115k complex64 (2, 21) matrices:
    torch.linalg.svd + einsum: 22.3 ms / call
    closed-form 2×2: 1.24 ms / call (18× faster)
    Pearson(inv_svd, inv_cf): 0.99999994
    max abs diff: 1.87e-7

Full pipeline validation on ops0154 well A/1 (148 positions, NAdam 3
iters, 2×H200): 2D recon stage 4.08 s/pos → 0.52 s/pos (7.9×), total wall
7:55 → 5:57. Phase Pearson vs NAdam-6 reference: median 0.9983, min
0.9923, all 148 positions ≥ 0.99; these Pearson statistics are
bit-identical to those of the SVD path.

Gated by WAVEORDER_FAST_2D_TIKHONOV=1 env var. Only fires in no-grad mode
(autograd path uses the use_svd=False norm-based decomposition, which is a
different mathematical approximation that assumes channel independence).
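The closed form can be checked on a single small matrix with a pure-Python sketch. The function name `tikhonov_inverse_2xz` and the sample matrix are illustrative only; the real code path operates on batched complex64 tensors.

```python
def tikhonov_inverse_2xz(M, lam):
    """Return (M M^H + lam*I)^{-1} M for a 2 x Z complex matrix M.

    M is a list of two rows, each a list of Z complex numbers.
    """
    row0, row1 = M
    # G = M M^H + lam*I = [[a, c], [conj(c), d]] is 2x2 Hermitian.
    a = sum(abs(x) ** 2 for x in row0) + lam
    d = sum(abs(y) ** 2 for y in row1) + lam
    c = sum(x * y.conjugate() for x, y in zip(row0, row1))
    det = a * d - abs(c) ** 2  # real and positive when lam > 0
    # G^{-1} = (1/det) * [[d, -c], [-conj(c), a]], applied column by column.
    out0 = [(d * x - c * y) / det for x, y in zip(row0, row1)]
    out1 = [(-c.conjugate() * x + a * y) / det for x, y in zip(row0, row1)]
    return [out0, out1]


M = [[1 + 2j, 0.5j, -1.0 + 0j], [2 - 1j, 1.0 + 0j, 0.25 + 0.75j]]
lam = 1e-2
inv = tikhonov_inverse_2xz(M, lam)

# Verify by multiplying back: G @ inv should reproduce M.
a = sum(abs(x) ** 2 for x in M[0]) + lam
d = sum(abs(y) ** 2 for y in M[1]) + lam
c = sum(x * y.conjugate() for x, y in zip(M[0], M[1]))
max_residual = max(
    max(abs(a * u + c * v - m0), abs(c.conjugate() * u + d * v - m1))
    for u, v, m0, m1 in zip(inv[0], inv[1], M[0], M[1])
)
```

The multiply-back check avoids needing an SVD reference: if `G @ inv` equals `M` to rounding error, `inv` is the regularized filter.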
…-skip API
New module ``waveorder.models.phase_thick_3d_tilt`` providing:
- ``optimize_subtile_tilt_params(...)`` — batched NAdam optimizer over
per-subtile (zenith, azimuth, z_offset) tilt parameters, using
``isotropic_thin_3d.reconstruct`` as the forward model. Internally
groups subtiles by focus offset and shape so the forward TF is
computed once per group. Supports ``freeze_axes=("zenith","azimuth")``
for the 1-D z-only path that's significantly faster when the
warmstart map provides reliable angle estimates.
- ``warmstart_params`` + ``skip_optim_if_warmstart`` kwargs — the
algorithm hook for caller-side skip-opt / T-cache. When set,
``optimize_subtile_tilt_params`` bypasses the NAdam loop and returns
the warmstart verbatim. The caller (e.g. ops_process) owns the skip
decision; the library just honors it.
- ``radial_blend_zenith_init(...)`` — pure utility for the validated
zen_blend recipe (smooth radial ramp from 0 at well center to the
per-subtile formula value at the edge). Used for low-NA tilt-recon
on track-style FOVs.
- ``TiltOptimResult`` dataclass — explicit result type with per-subtile
outputs, final loss, iteration count, and a ``skipped`` flag.
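As a rough illustration of the radial blend, the sketch below ramps a zenith init from 0 at the well center to the formula value at the edge. The smoothstep weight is an assumed ramp shape, and `radial_blend_sketch` with its arguments is a hypothetical stand-in, not the signature of ``radial_blend_zenith_init``.

```python
def radial_blend_sketch(formula_zenith, r, r_max):
    """Blend a zenith init from 0 at the center to formula_zenith at the edge.

    The smoothstep weight 3t^2 - 2t^3 is an assumption about the ramp shape.
    """
    if r_max <= 0:
        raise ValueError("r_max must be positive")
    t = min(max(r / r_max, 0.0), 1.0)  # normalized radius clamped to [0, 1]
    weight = t * t * (3.0 - 2.0 * t)
    return weight * formula_zenith


center = radial_blend_sketch(0.04, r=0.0, r_max=10.0)
halfway = radial_blend_sketch(0.04, r=5.0, r_max=10.0)
edge = radial_blend_sketch(0.04, r=10.0, r_max=10.0)
```

A subtile at the well center starts from zenith 0, one at the edge starts from the full formula value, and intermediate radii get a smoothly weighted mix.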
Tests cover both the radial-blend utility and the optimizer (synthetic
recovery on CPU/CUDA, frozen-axis behavior, warmstart-skip roundtrip,
shape-check error path).
Algorithm body lifted from PR mehta-lab#99's ``_gpu_optimize_tilt_params`` in
royerlab/ops_process. Empirically validated this session: median phase
Pearson 0.994 vs PROD on ops0154 pheno (7035 positions), 0.988 on
ops0154 track. Per-position compute 8-12× faster than the vanilla
NAdam(5,15)/NAdam(10,25) recipes when paired with the closed-form 2×2
Tikhonov inverse (already in this branch).
This is the headline new public API for the tilt-recon waveorder PR.
The corresponding ops_process adapter PR (to be opened against
royerlab/ops_process main) will replace PR mehta-lab#99's monolithic
``_gpu_optimize_tilt_params`` body with a call to this function.
Headline
This is a refactor of the optimizations from https://github.com/royerlab/ops_process/pull/99. Optimizations have been pushed down to the waveorder level in order to make them more generically applicable to other projects. The idea is to offer this as a starting point for Shalin's 3-month "Scaling WaveOrder 3.0 to Biohub data" project (proposal dated 2026-05-29). That project's 20× target needs a bit more work, but this achieves roughly 4-16× per-GPU and 6.6-6.8× wallclock on ops0154 — at PROD cell-count parity across all wells (track ±0.18%, pheno ±0.05%) once the OPS adapter applies the meniscus override (PR3).
WaveOrder-on-OPS tilt-recon is 4-16× more GPU-efficient per cell than the vanilla per-position SLURM-array path (16× on pheno from `skip_opt` + warmstart map, 4× on track since ~40% of positions run the full 25-iter PROD recipe via the meniscus override to hit cell-count parity). Pipeline wallclock speedups are larger because we additionally amortize SLURM-array dispatch via INPROC and run more GPUs in parallel.
Performance vs vanilla per-position SLURM-array on ops0154
Test data: 7,035 pheno positions × 1 timepoint × 3 wells; 296 track positions × 2 timepoints × 2 wells.
The Vanilla GPU-h numbers come from a direct per-position SLURM-array benchmark (100 pheno tasks, 50 track tasks, 1 GPU per task on H100/H200/Blackwell). Per-position median wall × position count gives the theoretical compute floor of the vanilla path.
Per-GPU speedup is the algorithmic compute-per-cell improvement, independent of parallelism. Pheno gains more (16×) because `warmstart_map` + `skip_opt` skips NAdam entirely on most positions; track gains less (4×) because the meniscus override at `r/r_max > 0.70` puts ~40% of positions on the full 25-iter PROD recipe to hit cell-count parity.
Pipeline wallclock includes additional benefits of INPROC dispatch (eliminates SLURM-array per-task startup tax; PROD's 5h30m track wall was dominated by scheduling overhead, not compute) and running multiple GPUs in parallel.
Multi-timepoint amortization. The numbers above are for ops0154, where pheno has 1 timepoint per position and track has 2. The T-cache (memoize optimized tilt params per position — first call runs full optim, subsequent T's at the same position skip the 2D Adam phase) means per-position cost drops as more timepoints are added. Track's 4× at N=2 already includes ~2× from T-cache; many-T workloads (e.g. a LiveScreen-style experiment) will see substantially larger savings.
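The T-cache idea can be sketched as plain memoization keyed by position. The class name `TiltParamCache`, the `fake_optimize` stand-in, and the position key format are all hypothetical; the actual cache lives on the OPS side.

```python
class TiltParamCache:
    """Memoize optimized tilt parameters per position across timepoints."""

    def __init__(self, optimize_fn):
        self._optimize_fn = optimize_fn
        self._cache = {}
        self.optimizer_calls = 0

    def get(self, position, data):
        # First timepoint at a position runs the optimizer; later ones reuse it.
        if position not in self._cache:
            self.optimizer_calls += 1
            self._cache[position] = self._optimize_fn(data)
        return self._cache[position]


def fake_optimize(data):
    """Stand-in for the expensive per-position tilt optimization."""
    return {"zenith": 0.01, "azimuth": 0.5, "z_offset": sum(data) / len(data)}


cache = TiltParamCache(fake_optimize)
timepoints = {("A/1", 7): [[1.0, 2.0, 3.0], [1.1, 2.1, 3.1]]}
results = [
    cache.get(position, frame)
    for position, frames in timepoints.items()
    for frame in frames
]
```

With two timepoints at one position the optimizer runs once, so the per-timepoint optimization cost halves; with more timepoints the amortized cost keeps dropping.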
Cell-count parity vs PROD
Pheno: PASS, all 3 wells within ±0.05% of PROD on ops0154 cell-count verify (-0.018% / -0.026% / -0.022%). Acceptance ±0.5%.
Track: PASS, all 4 well-timepoints within ±0.18% of PROD with the OPS-side meniscus override at `OPS_TILT_MENISCUS_R_THRESHOLD=0.70`.
The track cell-count residual at edge positions was traced to the speed recipe (`zen_blend` + frozen angles + radius-scaled 3-25 iters + T-cache) diverging from PROD's full 10-25-iter unfrozen NAdam. The OPS adapter PR (royerlab/ops_process#104) detects meniscus positions (r/r_max > T) and routes them to the PROD recipe (full 25-iter NAdam, no zen_blend, no T-cache, no warmstart skip). In the threshold sweep, `0.70` is the highest threshold that passes ±0.5%, which makes it the optimum (lower thresholds add compute without improving accuracy).
A five-way env-flag bisect on meniscus positions (toggling `WAVEORDER_FAST_2D_TIKHONOV`, GPU-shared-optics, batched-vs-per-subtile NAdam, and including a "vanilla waveorder 3.0.2 + ops_process main" reference run) showed all variants within ±0.02% of each other in meniscus-band cell counts. So none of this PR's three numerical optimizations (closed-form Tikhonov, GPU-resident optics, batched NAdam) is the cause of the residual; the residual was purely the recipe choice at edge positions, which the OPS-side override fixes.
Tilt-recon quality examples: best / median / worst per branch
A/1 well, T=0. Pearson(ours vs PROD) computed per FOV on the 2D phase output. Each panel: PROD (top) / ours (middle) / abs-diff (bottom). Images live in the OPS adapter PR's `pr_artifacts/`.
Track (148 positions, threshold=0.70, Pearson: min 0.941, median 0.972, max 1.000):
Pheno (2,345 positions in A/1, Pearson: min 0.937, median 0.979, max 0.994):
The "worst" panels show disagreement concentrated at high-contrast cell-edge pixels — biological structure is preserved across all panels.
What's in this PR
Six pieces, all of which transfer cleanly to the upstream project:
1. `waveorder.models.phase_thick_3d_tilt.optimize_subtile_tilt_params`: batched per-subtile NAdam over N subtiles × 3 params (z_offset, zenith, azimuth) each. Existing `optimize_reconstruction` is structurally single-parameter-set; the batched-subtile shape is genuinely new API.
2. `waveorder.models.phase_thick_3d_tilt.radial_blend_zenith_init`: radial blend of per-subtile zen init from a baseline formula. Used for low-NA track tilt to prevent zen=0 init from drifting to the wrong focal plane on well-edge positions.
3. GPU-resident shared optics: `_compute_shared_optics(device=...)` builds propagation kernel + Green's function on-device instead of CPU `torch.exp` + `.to(device)`. Verified 1.79× on ops0154 pheno full-scale (89 min → 49.7 min on Blackwell 8-GPU). GPU util 13% → 78%.
4. Closed-form 2×2 Tikhonov inverse in `isotropic_thin_3d.reconstruct`: replaces SVD path with a closed-form Hermitian inverse. Mathematically equivalent, ~18× faster per call. Bit-identical Pearson on ops0154.
5. `CachedTiltOptics`: angle-fixed optics computed once per well and reused when `freeze_axes=("zenith","azimuth")`. ~2× speedup on track-tilt where angles are frozen post-calibration.
6. LM-damped diagonal Newton, added as a `method='newton'` option to `optimize_reconstruction`.
Tests
`tests/models/test_phase_thick_3d_tilt.py`: 8 tests covering shape invariants, synthetic recovery, frozen-axis behavior, warmstart-skip roundtrip, and the radial-blend zero-r edge case. All pass on CPU and CUDA.
Anti-patterns deliberately avoided (per OPS memory)
Open caveats
`ops_process` PR (per-well SLURM fan-out, universal warmstart map, meniscus override `OPS_TILT_MENISCUS_R_THRESHOLD`) bake in OPS-specific assumptions and are likely candidates for rework when the WaveOrder 3.0 engineer designs the generalizable layer. Those are intentionally kept out of this PR.
Draft
This is a draft while the OPS-side ops_process / ops_monorepo PRs that pin to this branch land. Numbers above are reproducible from `mark/ops0154_cell_count_verify` (this-stack run at threshold 0.85), `mark/ops0154_bisect_men_0p{50,60,70,75}` (threshold sweep), `mark/ops0154_bisect_vanilla` (PROD-stack reference), `mark/ops0154_vanilla_pheno_benchmark` / `mark/ops0154_vanilla_track_benchmark` (vanilla extrapolations) on the Biohub HPC.
🤖 Generated with Claude Code