n_workers kwarg for FilterRecording and CommonReferenceRecording #4564
galenlynch wants to merge 8 commits into SpikeInterface:main
Conversation
Test failures are unrelated to this PR
Thanks @galenlynch! We have discussed this internally and we think that having per-processor thread parallelization could be a common and useful use case. Implementation-wise, we think we could move the pool creation to core/job_tools.py and make it more general with respect to which functions it applies to. Then each extractor could implement a … What do you think? We can help with the implementation side :)
The initial thread is really hard to read. I am fine with LLM contributions but I think they should be curated somehow.
@h-mayorquin I can clean up the initial thread. FWIW I did a lot of testing of this PR and in my hands found a massive speedup, particularly when using …

Edit: original message rewritten.
@alejoe91 centralizing threading control seems like a great idea! Adding … One thing worth pointing out: as far as I understand (see spikeinterface/src/spikeinterface/core/job_tools.py, lines 631 to 632 in ec55b00), … However, `scipy.sosfiltfilt` and `np.median` are both single-threaded C code that does not use BLAS, which is why this PR uses explicit Python-level fan-out with a ThreadPoolExecutor. That is a different axis of threading from the one controlled by `max_threads_per_worker` in TimeSeriesChunkExecutor via WorkerFuncWrapper. There is yet another axis of parallelism in the sister PR #4563: numba `prange`. Having the same number control all of these could open the door to thread oversubscription. Maybe each BaseRecordingSegment's `get_traces_multi_thread` could decide how to spend that budget?
Maybe one way to do this would be to add building blocks in …:

```python
# Per-caller-thread pool, what the current PR builds inline in filter.py
def get_inner_pool(max_threads: int) -> ThreadPoolExecutor:
    """Per-caller-thread ThreadPoolExecutor (WeakKeyDictionary keyed by Thread)."""
    ...

@contextmanager
def thread_budget(max_threads: int, *, blas: bool = False, numba: bool = False):
    """Cap the named runtimes for the duration of the context."""
    ...
```

Dispatch in `BaseRecording`:

```python
class BaseRecording:
    def get_traces_multi_thread(
        self, *, segment_index=None, start_frame=None, end_frame=None,
        channel_ids=None, max_threads=None,
    ):
        """Like get_traces, but allowed to use up to max_threads
        threads internally. Default falls back to get_traces."""
        if max_threads is None:
            max_threads = get_global_job_kwargs()["max_threads_per_worker"]
        if max_threads <= 1:
            return self.get_traces(...)
        rs = self.segments[segment_index]
        return rs.get_traces_multi_thread(start_frame, end_frame,
                                          channel_indices, max_threads)
```

And different extractors could choose how to spend the budget:

```python
# Filter (this PR, wants ThreadPoolExecutor):
class FilterRecordingSegment:
    def get_traces_multi_thread(self, start, end, channel_indices, max_threads):
        traces = self._fetch_with_margin(start, end, channel_indices)
        pool = get_inner_pool(max_threads)
        return self._apply_sos_parallel(traces, pool)  # GIL-released scipy

# CMR (this PR, wants ThreadPoolExecutor):
class CommonReferenceRecordingSegment:
    def get_traces_multi_thread(self, start, end, channel_indices, max_threads):
        traces = self._fetch(start, end, channel_indices)
        pool = get_inner_pool(max_threads)
        return self._parallel_reduce_axis1(traces, pool)

# PhaseShift FIR (sister PR, numba prange only):
class PhaseShiftRecordingSegment:
    def get_traces_multi_thread(self, start, end, channel_indices, max_threads):
        traces = self._fetch(...)
        with thread_budget(max_threads, numba=True):
            return _sinc_fir_kernel_tc(traces, self._kernels)

# Hypothetical BLAS-heavy segment (example, BLAS threading only):
class WhitenRecordingSegment:
    def get_traces_multi_thread(self, start, end, channel_indices, max_threads):
        traces = self._fetch(...)
        with thread_budget(max_threads, blas=True):
            return self._W @ traces.T
```
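The `thread_budget` stub above is left as `...`; a minimal sketch of how it could be implemented is below. This is an assumption, not the PR's code: it caps BLAS pools via the optional `threadpoolctl` package and numba via `numba.set_num_threads`, and silently no-ops for any backend that is not installed.

```python
from contextlib import contextmanager, ExitStack

@contextmanager
def thread_budget(max_threads, *, blas=False, numba=False):
    """Cap the named runtimes for the duration of the context (sketch)."""
    with ExitStack() as stack:
        if blas:
            try:
                # threadpoolctl caps BLAS/OpenMP pools; optional dependency.
                from threadpoolctl import threadpool_limits
                stack.enter_context(threadpool_limits(limits=max_threads))
            except ImportError:
                pass  # no BLAS control available; run uncapped
        if numba:
            try:
                import numba as _nb
                prev = _nb.get_num_threads()
                _nb.set_num_threads(max_threads)
                # Restore the previous numba thread count on exit.
                stack.callback(_nb.set_num_threads, prev)
            except ImportError:
                pass  # numba not installed; nothing to cap
        yield
```

Because each backend is optional, the context manager is always safe to enter; callers pay only for the runtimes they actually name.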
Adds opt-in intra-chunk thread-parallelism to two preprocessors: channel-split sosfilt/sosfiltfilt in FilterRecording, time-split median/mean in CommonReferenceRecording. Default n_workers=1 preserves existing behavior.

Per-caller-thread inner pools
-----------------------------

Each outer thread that calls ``get_traces()`` on a parallel-enabled segment gets its own inner ThreadPoolExecutor, stored in a ``WeakKeyDictionary`` keyed by the calling ``Thread`` object. Rationale:

* Avoids the shared-pool queueing pathology that would occur if N outer workers (e.g., TimeSeriesChunkExecutor with n_jobs=N) all submitted into a single shared pool with fewer max_workers than outer callers. Under a shared pool, ``n_workers=2`` with ``n_jobs=24`` thrashed at 3.36 s on the test pipeline; per-caller pools: 1.47 s.
* Keying by the Thread object (not a thread-id integer) avoids the thread-id-reuse hazard: thread IDs can be reused after a thread dies, which would cause a new thread to silently inherit a dead thread's pool.
* WeakKeyDictionary + weakref.finalize ensures automatic shutdown of the inner pool when the calling thread is garbage-collected. The finalizer calls ``pool.shutdown(wait=False)`` to avoid blocking the finalizer thread; in-flight tasks would be cancelled, but the owning thread submits and joins synchronously, so none exist when it exits.

When useful
-----------

* Direct ``get_traces()`` callers (interactive viewers, streaming consumers, mipmap-zarr tile builders) that don't use ``TimeSeriesChunkExecutor``.
* Default SI users who haven't tuned job_kwargs.
* RAM-constrained deployments that can't crank ``n_jobs`` to the core count: on a 24-core host, ``n_jobs=6, n_workers=2`` gets within 8% of ``n_jobs=24, n_workers=1`` at ~1/4 the RAM.
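The per-caller-thread pool contract described above can be sketched as follows. The names (`get_inner_pool`, `_pools`) are illustrative stand-ins, not the PR's exact code; the mechanism (WeakKeyDictionary keyed by the Thread object, `weakref.finalize` for non-blocking shutdown) is the one the commit message describes.

```python
import threading
import weakref
from concurrent.futures import ThreadPoolExecutor

# One inner pool per *calling* thread, released when that thread is collected.
_pools = weakref.WeakKeyDictionary()
_pools_lock = threading.Lock()

def get_inner_pool(max_threads: int) -> ThreadPoolExecutor:
    """Return an inner pool owned by the calling thread.

    Keyed by the Thread object itself (not its integer ident), so a
    recycled thread id can never inherit a dead thread's pool.
    """
    caller = threading.current_thread()
    with _pools_lock:
        pool = _pools.get(caller)
        if pool is None:
            pool = ThreadPoolExecutor(max_workers=max_threads)
            _pools[caller] = pool
            # Shut the pool down without blocking once the owner is gone.
            weakref.finalize(caller, pool.shutdown, wait=False)
    return pool
```

Repeated calls from one thread reuse the same executor, while two concurrent outer workers each get their own, which is exactly what avoids the shared-pool queueing pathology measured above.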
Performance (1M × 384 float32 BP+CMR pipeline, 24-core host, thread engine)
---------------------------------------------------------------------------

=== Component-level (scipy/numpy only) ===
sosfiltfilt serial → 8 threads:   7.80 s → 2.67 s (2.92x)
np.median   serial → 16 threads:  3.51 s → 0.33 s (10.58x)

=== Per-stage end-to-end (rec.get_traces) ===
Bandpass (5th-order, 300-6k Hz):  8.59 s → 3.20 s (2.69x)
CMR median (global):              4.01 s → 0.81 s (4.95x)

=== CRE outer × inner Pareto, per-caller pools ===
outer=24, inner=1 each: 1.54 s (100% of peak)
outer=24, inner=8 each: 1.42 s (108% of peak; oversubscribed)
outer=12, inner=1 each: 1.59 s (97%, ~1/2 RAM of outer=24)
outer=6,  inner=2 each: 1.75 s (92%, ~1/4 RAM of outer=24)
outer=4,  inner=6 each: 1.83 s (87%, ~1/6 RAM with 24 threads)

Tests
-----

New ``test_parallel_pool_semantics.py`` verifies the per-caller-thread contract: a single caller reuses one pool; concurrent callers get distinct pools. Existing bandpass + CMR tests still pass.

Independent of the companion FIR phase-shift PR (perf/phase-shift-fir); the two can land in either order.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The per-caller-thread pool dict on FilterRecordingSegment and CommonReferenceRecordingSegment is keyed by Thread object via a WeakKeyDictionary. Across os.fork(), Python reuses the calling thread's identity in the child, so the child's first lookup returns the parent's ThreadPoolExecutor, whose worker OS threads do not exist in the child. The child's first submit() then blocks indefinitely.

Reproducer: parent calls get_traces() (lazily creating the pool), then runs save() / write_binary_recording() with n_jobs > 1 and the default fork start method on Linux. Child workers hang in Sl state with 0% CPU.

Fix: stash os.getpid() alongside the pool dict. In _get_pool, if the current pid differs, rebuild the dict and lock from scratch before proceeding. Pickling (mp_context="spawn"/"forkserver") goes through __reduce__ → __init__ and gets fresh state already, so this guard is specifically for the fork copy-of-memory path.

Adds a regression test that pre-warms the pool, forks via the mp fork context, and asserts get_traces() in the child completes within 30 s. Without the guard the test deadlocks; with it, it passes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
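The pid-guard pattern described in this commit can be sketched as below. Class and attribute names (`PoolOwner`, `_pool_pid`, `_get_pool`) are hypothetical stand-ins for the segments' actual code; the guard logic is the one the message describes.

```python
import os
import threading
import weakref
from concurrent.futures import ThreadPoolExecutor

class PoolOwner:
    """Illustrative owner of per-caller-thread pools with a fork guard."""

    def __init__(self):
        self._init_pool_state()

    def _init_pool_state(self):
        # Remember which process created this state.
        self._pool_pid = os.getpid()
        self._pool_lock = threading.Lock()
        self._pools = weakref.WeakKeyDictionary()

    def _get_pool(self, max_threads):
        # Fork guard: a pid mismatch means we are in a forked child, and
        # the inherited executors have no live worker threads. Rebuild.
        # (Safe to check outside the lock: a freshly forked child has
        # exactly one thread, the one executing this call.)
        if self._pool_pid != os.getpid():
            self._init_pool_state()
        caller = threading.current_thread()
        with self._pool_lock:
            pool = self._pools.get(caller)
            if pool is None:
                pool = ThreadPoolExecutor(max_workers=max_threads)
                self._pools[caller] = pool
        return pool
```

The guard can be exercised without actually forking by tampering with the stored pid: after setting `_pool_pid` to a stale value, the next `_get_pool` call discards the old executor and builds fresh state, which is exactly what a forked child needs.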
The previous dispatch had each parallel worker return ``(c0, c1, block)`` tuples; the calling thread then allocated the output array and copied each block into place. That post-collection allocate-and-copy is wasted work since the channel/time slices are non-overlapping: workers can write directly into a pre-allocated output.

Measured on a (30000, 384) float32 chunk with sosfiltfilt and n_workers=5:

pattern                               wall (ms)  speedup
E. sequential                         173.89     1.00×
A. submit + collect + alloc + copy     75.66     2.30×  (current)
B. pre-alloc, write in place           60.51     2.87×  (this PR)
C. pool.map, write in place            63.55     2.74×
D. manual threading.Thread             64.76     2.69×

So we save ~15 ms wall per `_apply_sos` call (likewise for `_parallel_reduce_axis1`) by dropping the redundant copy. Ideal 5× scaling would be 34.78 ms; the remaining gap to ideal is the GIL-held Python wrapper inside scipy's sosfiltfilt. The dispatch pattern doesn't matter there (B/C/D are all within noise), so we keep the simpler submit/result form.

Same pattern applied to common_reference._parallel_reduce_axis1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
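A minimal sketch of the write-in-place dispatch (pattern B above): the output is allocated once, and each worker writes its non-overlapping channel slice directly. `np.cumsum` stands in for `scipy.signal.sosfiltfilt` to keep the example self-contained; the function name `apply_inplace_parallel` is illustrative, not the PR's.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def apply_inplace_parallel(traces, func, pool, n_blocks):
    """Apply a per-channel-slice function in parallel, writing in place."""
    n_channels = traces.shape[1]
    out = np.empty_like(traces)  # single allocation, no copy-back later
    bounds = np.linspace(0, n_channels, n_blocks + 1).astype(int)

    def work(c0, c1):
        # Non-overlapping column slices: no locking needed.
        out[:, c0:c1] = func(traces[:, c0:c1])

    futures = [pool.submit(work, c0, c1)
               for c0, c1 in zip(bounds[:-1], bounds[1:]) if c1 > c0]
    for f in futures:
        f.result()  # propagate any worker exception
    return out
```

Because `func` here is GIL-releasing numpy/scipy C code in the real use case, the threads genuinely overlap; the only change from pattern A is that the per-block results never pass through an intermediate tuple-and-copy stage.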
for more information, see https://pre-commit.ci
Switch from "one chunk per worker" to "many small chunks dispatched FIFO to a fixed-size pool", sized so the per-chunk input fits L2 (~1.5 MB) and n_workers active chunks fit shared L3. All workers tend to cluster in the same time region of the input at any moment, so shared L3 absorbs the data once instead of n_workers independent streams competing for DRAM.

Empirical wins on a 24-core x86_64 host with n_workers=16, fp32:

T (samples)   OLD (1 chunk/wkr)   NEW        speedup
30_000        44.6 ms             24.8 ms    1.80x
524_288       397.8 ms            358.6 ms   1.11x

The 1.80x at T=30k (SI default chunk_duration='1s' at fs=30 kHz) is the bigger win: the OLD min_block=8192 only used T // 8192 = 3 effective workers at that T, leaving 13 idle. The NEW scheme dispatches enough chunks (~30 at T=30k) for all 16 workers to do useful work. The 1.11x at T=524k is smaller because the OLD code already used all workers there; the NEW scheme just shifts the cache pattern. A direct numpy bench (no SI plumbing) shows ~1.4x at this T; SI's get_traces overhead dilutes that to 1.11x end-to-end.

Diminishing returns set in past ~16 chunks/worker, where dispatch overhead starts to compete with the cache win. The block-size formula caps total chunks at 64 * n_workers and floors the block at 256 rows.

No new feature; same n_workers kwarg, same correctness invariants. Existing 12 CMR + parallel-pool tests pass unchanged.
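The block-size heuristic this commit describes (fit one block's rows in ~1.5 MB of L2, cap total chunks at 64 * n_workers, floor the block at 256 rows) could look like the sketch below. The function name and exact constants-as-code are assumptions; the constants' values are the ones stated in the message.

```python
# Assumed constants, taken from the commit message above.
L2_BYTES = int(1.5 * 2**20)   # ~1.5 MB usable L2 per core
MIN_BLOCK_ROWS = 256          # floor on the block size
MAX_CHUNKS_PER_WORKER = 64    # cap total chunks at 64 * n_workers

def choose_block_rows(n_samples, n_channels, itemsize, n_workers):
    """Pick a time-block size (rows) for FIFO dispatch to the inner pool."""
    row_bytes = n_channels * itemsize
    block = max(1, L2_BYTES // row_bytes)  # one block's input fits in L2
    # Don't flood the queue: keep total chunks <= 64 * n_workers.
    min_block_for_cap = -(-n_samples // (MAX_CHUNKS_PER_WORKER * n_workers))
    block = max(block, min_block_for_cap)
    block = max(block, MIN_BLOCK_ROWS)     # floor at 256 rows
    return min(block, n_samples)
```

At the benchmarked point (T=30_000 samples, 384 float32 channels, n_workers=16) this yields 1024-row blocks, i.e. roughly 30 chunks, matching the "~30 at T=30k" figure quoted above.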
not needed in PR
7c8530c to 4fbd1a6
I was attempting to visualize my recordings as they would be seen by Kilosort4 in a spikeinterface pipeline, but the preprocessing performance was dreadfully slow for 90-minute recordings. I have since realized that this library primarily uses a batch-processing model with `TimeSeriesChunkExecutor`, and parallelizes over chunks of data. However, `TimeSeriesChunkExecutor` is difficult to use if you just want to see a small stretch of a recording after preprocessing, which from my understanding requires preprocessing the entire recording, or at least saving to disk. Other workflows might also benefit from faster direct access to the data, instead of using the batch-processing API.

Because of this design choice, inner extractors are written to be single-threaded. This PR introduces extractor-level parallelism for `FilterRecording` and `CommonReferenceRecording`, intended to speed up `get_traces` calls that are not inside `TimeSeriesChunkExecutor`. The high-level takeaway is that on a 384 × 1M sample data segment, `FilterRecording` is now 2.8 times faster and `CommonReferenceRecording` is 5.4 times faster. Those times reflect the entire extractor with overhead; the speedup for the core computation is 2.9x for `scipy.sosfiltfilt` and 10.6x for `np.median`. The core algorithm was not changed for either, and the multi-threading is opt-in by raising `n_workers` past its default of 1.

The companion PR #4563 was also intended to speed up `get_traces` calls by addressing the lion's share of the processing time: `PhaseShiftRecording`. However, the majority of the speedup in that PR is due to an algorithmic improvement and not simply extractor-level parallelism, which is why I separated it. When these PRs are combined and tested on my data, the entire `get_traces()` call with phase shifting, CMR, and filtering is 13.1 times faster.

But how do these changes compose with `TimeSeriesChunkExecutor`? The good news is that adding extractor-level parallelism under `TimeSeriesChunkExecutor` can actually improve performance, and it might be a nice knob to have for RAM-conscious users who nonetheless have modern CPUs with many cores. See the table below for band-pass + common-reference filter chains, with 'outer' meaning `TimeSeriesChunkExecutor` `n_jobs` and 'inner' meaning extractor-level `n_workers`.

Setup: chunk_duration="1s" (SI default), different splits of a ~24-thread compute budget on the BP+CMR pipeline, per-caller-thread pools:
We can split this up for each extractor, with CRE meaning `TimeSeriesChunkExecutor`.

Band pass specifically (inner pool = 8, matching CRE n_jobs=8):

Common mode reference specifically (inner pool = 16, exceeds CRE n_jobs=8):
Compatibility

* `n_workers=1` preserves existing semantics exactly.
* `_kwargs` dicts updated on both preprocessors; `save()`/`load()` round-trip the new kwargs correctly.
* Only standard-library machinery: `concurrent.futures.ThreadPoolExecutor`, `threading`, `weakref`.
* Exposed through the `bandpass_filter`, `highpass_filter`, `filter`, `notch_filter`, and `common_reference` wrapper functions via `**filter_kwargs`.

Companion PR
An independent companion PR #4563 adds a sinc-FIR alternative to `PhaseShiftRecording` with ~100× per-stage speedup while reducing memory. Combined, they give 13–20× on a typical `PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecording` chain for direct `get_traces()` callers, or ~3× on top of existing CRE parallelism.