
n_workers kwarg for FilterRecording and CommonReferenceRecording #4564

Open

galenlynch wants to merge 8 commits into SpikeInterface:main from galenlynch:perf/parallel-filter-cmr

Conversation

@galenlynch (Contributor) commented on Apr 24, 2026

I was attempting to visualize my recordings as they would be seen by Kilosort4 in a spikeinterface pipeline, but the preprocessing performance was dreadfully slow for 90-minute recordings. I have since realized that this library primarily uses a batch processing model built around TimeSeriesChunkExecutor, which parallelizes over chunks of data. However, TimeSeriesChunkExecutor is difficult to use if you just want to see a small stretch of a recording after preprocessing: as far as I understand, that requires preprocessing the entire recording, or at least saving it to disk. Other workflows might also benefit from faster direct access to the data instead of going through the batch-processing API.

Because of this design choice, inner extractors are written to be single-threaded. This PR introduces extractor-level parallelism for FilterRecording and CommonReferenceRecording, intended to speed up get_traces calls that are not inside TimeSeriesChunkExecutor. The high-level takeaway is that on a 384-channel × 1M-sample data segment, FilterRecording is now 2.8 times faster and CommonReferenceRecording is 5.4 times faster. Those times reflect the entire extractor with overhead; the speedup for the core computation is 2.9× for scipy.signal.sosfiltfilt and 10.6× for np.median. The core algorithm is unchanged for both, and the multi-threading is opt-in: raise n_workers past its default of 1.
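From user code, the opt-in looks like this (a minimal sketch; generate_recording stands in for a real recording, and the new n_workers kwarg reaches the preprocessors through the wrapper functions, as listed under Compatibility below):

```python
import spikeinterface.core as sc
from spikeinterface.preprocessing import bandpass_filter, common_reference

rec = sc.generate_recording(num_channels=384, durations=[60.0])

# Opt in per preprocessor; n_workers=1 (the default) is the existing behavior.
rec_bp = bandpass_filter(rec, freq_min=300.0, freq_max=6000.0, n_workers=8)
rec_cmr = common_reference(rec_bp, operator="median", n_workers=16)

# Direct get_traces() call -- the visualization path this PR targets.
traces = rec_cmr.get_traces(segment_index=0, start_frame=0, end_frame=30_000)
```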

The companion PR #4563 is also intended to speed up get_traces calls, by addressing the stage that takes the lion's share of the processing time: PhaseShiftRecording. However, the majority of the speedup in that PR comes from an algorithmic improvement rather than extractor-level parallelism, which is why I separated it.

When these PRs are combined and tested on my data, the entire get_traces() call with phase shifting, CMR, and filtering is 13.1 times faster.


But how do these changes compose with TimeSeriesChunkExecutor? The good news is that adding extractor-level parallelism under TimeSeriesChunkExecutor can actually improve performance, and it may be a nice knob for RAM-conscious users who nonetheless have modern CPUs with many cores. See the table below for band pass + common reference filter chains, with 'outer' meaning TimeSeriesChunkExecutor n_jobs and 'inner' meaning extractor-level n_workers:

Setup: chunk_duration="1s" (the SI default), per-caller-thread pools, and different splits of a ~24-thread compute budget on the BP+CMR pipeline:

| Budget | Config | Time | Notes |
|---|---|---|---|
| 24 threads | outer=24, inner=1 each | 1.54 s | clean, minimum thread count |
| 192 threads | outer=24, inner=8 each | 1.42 s | absolute peak, oversubscribed |
| 24 threads | outer=8, inner=3 each | 1.53 s | tied with outer=24, inner=1; ~⅓ RAM |
| 12 threads | outer=12, inner=1 each | 1.59 s | ~½ RAM of outer=24 |
| 12 threads | outer=6, inner=2 each | 1.75 s | ~¼ RAM of outer=24 |
| 12 threads | outer=4, inner=3 each | 1.92 s | ~⅙ RAM of outer=24 |
| 12 threads | outer=1, inner=12 each | 4.31 s | single caller; sync overhead dominates |
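Continuing the sketch above, one of the RAM-friendly rows would look something like this in user code (the folder path is illustrative):

```python
# outer=6, inner=2 row above: ~1/4 the RAM of n_jobs=24 at ~92% of its throughput
rec_bp = bandpass_filter(rec, freq_min=300.0, freq_max=6000.0, n_workers=2)
rec_cmr = common_reference(rec_bp, operator="median", n_workers=2)
rec_cmr.save(folder="preprocessed", n_jobs=6, chunk_duration="1s", progress_bar=True)
```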

We can split this up for each extractor, with CRE meaning TimeSeriesChunkExecutor:

Band pass specifically (inner pool = 8, matching CRE n_jobs=8):

| Config | Time | Speedup | Parallelism axis |
|---|---|---|---|
| stock, CRE n=1 (baseline) | 7.42 s | 1.00× | |
| stock, CRE n=8 thread | 1.40 s | 5.29× | outer only |
| n_workers=8, CRE n=1 | 2.39 s | 3.04× | inner only |
| n_workers=8, CRE n=8 thread | 1.24 s | 6.00× | both |

Common reference specifically (inner pool = 16, exceeding CRE n_jobs=8):

| Config | Time | Speedup | Parallelism axis |
|---|---|---|---|
| stock, CRE n=1 (baseline) | 3.93 s | 1.00× | |
| stock, CRE n=8 thread | 0.62 s | 6.34× | outer only |
| n_workers=16, CRE n=1 | 0.86 s | 4.57× | inner only |
| n_workers=16, CRE n=8 thread | 0.33 s | 11.83× | both |

Compatibility

- No default behavior changes: n_workers=1 preserves existing semantics exactly.
- Round-trip dumpability: the _kwargs dicts are updated on both preprocessors, so save() / load() round-trip the new kwarg correctly (see the sketch after this list).
- No new dependencies: uses only stdlib concurrent.futures.ThreadPoolExecutor, threading, and weakref.
- n_workers propagates through the bandpass_filter, highpass_filter, filter, notch_filter, and common_reference wrapper functions via **filter_kwargs.
- Safe for long-running processes: per-caller pools are cleaned up when the calling thread is garbage-collected.
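A quick way to exercise the round-trip property (a sketch; the to_dict / load_extractor pair is the standard SI dumpability mechanism):

```python
from spikeinterface.core import load_extractor

rec_bp = bandpass_filter(rec, n_workers=8)
rec_roundtrip = load_extractor(rec_bp.to_dict())
assert rec_roundtrip._kwargs["n_workers"] == 8  # new kwarg survives the round trip
```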

Companion PR

An independent companion PR #4563 adds a sinc-FIR alternative to PhaseShiftRecording with ~100× per-stage speedup while reducing memory. Combined, they give 13–20× on a typical PhaseShiftRecording → HighpassFilterRecording → CommonReferenceRecording chain for direct get_traces() callers, or ~3× on top of existing CRE parallelism.

@galenlynch (Contributor, Author) commented:

Test failures are unrelated to this PR

alejoe91 added the preprocessing, concurrency, and performance labels on May 4, 2026
@alejoe91 (Member) commented on May 5, 2026

Thanks @galenlynch

We have discussed this internally and we think that per-processor thread parallelization could be a common and useful use case.

Implementation-wise, we think we could move the pool creation to core/job_tools.py and make it more general with respect to which functions it applies to. Each extractor could then implement a get_traces_multi_thread, which by default would fall back to get_traces but could be overridden by individual preprocessors (like filter and common_reference). We already have a parameter in the job_kwargs called max_threads_per_worker. This could be used to automatically enable multi-threaded trace retrieval, avoiding exposing n_workers at the init level. I think this is a better option, since the same class could support different paradigms (e.g., for visualization: 1 process + multi-threaded; for batch processing: N processes + 1 thread) without needing reinstantiation.

What do you think? We can help with the implementation side :)

@h-mayorquin (Collaborator) commented:

The initial thread is really hard to read. I am fine with LLM contributions, but I think they should be curated somehow.

@galenlynch (Contributor, Author) commented on May 8, 2026

@h-mayorquin I can clean up the initial thread. FWIW, I did a lot of testing of this PR, and in my hands it produced a massive speedup, particularly when using get_traces for visualization, which previously exposed no parallelism. TimeSeriesChunkExecutor seems to be the library's preferred model of parallelism, but it doesn't work for visualization workflows. The performance story is complicated by TimeSeriesChunkExecutor and requires explanation as a result.

Edit: original message rewritten.

@galenlynch (Contributor, Author) commented on May 8, 2026

@alejoe91 centralizing threading control seems like a great idea! Adding get_traces_multi_thread per extractor sounds like the right approach to me.

One thing worth pointing out: as far as I understand, max_threads_per_worker currently controls threadpoolctl.threadpool_limits, which caps parallelism in BLAS/OpenMP:

```python
with threadpool_limits(limits=self.max_threads_per_worker):
    return self.func(segment_index, start_frame, end_frame, self.worker_dict)
```

However, scipy.signal.sosfiltfilt and np.median are both single-threaded C code that does not use BLAS, which is why this PR fans out explicitly at the Python level with ThreadPoolExecutor. That is a different axis of threading from the one max_threads_per_worker controls in TimeSeriesChunkExecutor via WorkerFuncWrapper. There is yet another axis of parallelism in the sister PR #4563: numba prange. Having the same number control all of these could open the door to thread oversubscription. Maybe each BaseRecordingSegment's get_traces_multi_thread could decide how to spend that budget?

Maybe one way to do this would be to add building blocks in core/job_tools.py:

```python
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager

# Per-caller-thread pool, what the current PR builds inline in filter.py
def get_inner_pool(max_threads: int) -> ThreadPoolExecutor:
    """Per-caller-thread ThreadPoolExecutor (WeakKeyDictionary keyed by Thread)."""
    ...

@contextmanager
def thread_budget(max_threads: int, *, blas: bool = False, numba: bool = False):
    """Cap the named runtimes for the duration of the context."""
    ...
```
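For concreteness, here is one way those two stubs could be filled in, following the WeakKeyDictionary + weakref.finalize scheme this PR builds inline, plus the os.fork() guard described in the second commit message below. The module-level names are illustrative, not the PR's actual code:

```python
import os
import threading
import weakref
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager

_pools = weakref.WeakKeyDictionary()  # Thread -> ThreadPoolExecutor
_pools_lock = threading.Lock()
_pools_pid = os.getpid()

def get_inner_pool(max_threads: int) -> ThreadPoolExecutor:
    """Per-caller-thread pool, keyed by the Thread object so a reused thread
    id can never inherit a dead thread's pool.  Assumes max_threads is stable
    per calling thread (the pool is sized on first use)."""
    global _pools, _pools_lock, _pools_pid
    if os.getpid() != _pools_pid:
        # Forked child inherited the parent's dict, but the parent's worker
        # OS threads don't exist here -- rebuild from scratch (fork guard).
        _pools = weakref.WeakKeyDictionary()
        _pools_lock = threading.Lock()
        _pools_pid = os.getpid()
    caller = threading.current_thread()
    with _pools_lock:
        pool = _pools.get(caller)
        if pool is None:
            pool = ThreadPoolExecutor(max_workers=max_threads)
            _pools[caller] = pool
            # Shut the pool down when the owning thread is garbage-collected;
            # wait=False keeps the finalizer from blocking.
            weakref.finalize(caller, pool.shutdown, wait=False)
    return pool

@contextmanager
def thread_budget(max_threads: int, *, blas: bool = False, numba: bool = False):
    """Cap the named runtimes (BLAS/OpenMP via threadpoolctl, numba's thread
    count) for the duration of the context."""
    restore_numba = None
    if numba:
        import numba as nb
        restore_numba = nb.get_num_threads()
        nb.set_num_threads(max_threads)
    try:
        if blas:
            from threadpoolctl import threadpool_limits
            with threadpool_limits(limits=max_threads):
                yield
        else:
            yield
    finally:
        if restore_numba is not None:
            nb.set_num_threads(restore_numba)
```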

Dispatch in BaseRecording:

```python
class BaseRecording:
    def get_traces_multi_thread(
        self, *, segment_index=None, start_frame=None, end_frame=None,
        channel_ids=None, max_threads=None,
    ):
        """Like get_traces, but allowed to use up to max_threads
        threads internally.  Default falls back to get_traces."""
        if max_threads is None:
            max_threads = get_global_job_kwargs()["max_threads_per_worker"]
        if max_threads <= 1:
            return self.get_traces(segment_index=segment_index, start_frame=start_frame,
                                   end_frame=end_frame, channel_ids=channel_ids)
        channel_indices = self.ids_to_indices(channel_ids)
        rs = self._recording_segments[segment_index]
        return rs.get_traces_multi_thread(start_frame, end_frame,
                                          channel_indices, max_threads)
```

And different extractors could choose how to spend the budget:

```python
# Filter (this PR, wants ThreadPoolExecutor):
class FilterRecordingSegment:
    def get_traces_multi_thread(self, start, end, channel_indices, max_threads):
        traces = self._fetch_with_margin(start, end, channel_indices)
        pool = get_inner_pool(max_threads)
        return self._apply_sos_parallel(traces, pool)  # GIL-released scipy

# CMR (this PR, wants ThreadPoolExecutor):
class CommonReferenceRecordingSegment:
    def get_traces_multi_thread(self, start, end, channel_indices, max_threads):
        traces = self._fetch(start, end, channel_indices)
        pool = get_inner_pool(max_threads)
        return self._parallel_reduce_axis1(traces, pool)

# PhaseShift FIR (sister PR, numba prange only):
class PhaseShiftRecordingSegment:
    def get_traces_multi_thread(self, start, end, channel_indices, max_threads):
        traces = self._fetch(...)
        with thread_budget(max_threads, numba=True):
            return _sinc_fir_kernel_tc(traces, self._kernels)

# Hypothetical BLAS-heavy segment (example, BLAS threading only):
class WhitenRecordingSegment:
    def get_traces_multi_thread(self, start, end, channel_indices, max_threads):
        traces = self._fetch(...)
        with thread_budget(max_threads, blas=True):
            return self._W @ traces.T
```
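Under this scheme, a viewer would then make a single call like the following (hypothetical, using the dispatch signature sketched above):

```python
# One caller thread, inner threads only -- the visualization paradigm
traces = rec.get_traces_multi_thread(
    segment_index=0, start_frame=0, end_frame=30_000, max_threads=8
)
```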

galenlynch and others added 6 commits May 8, 2026 16:01
Adds opt-in intra-chunk thread-parallelism to two preprocessors:
channel-split sosfilt/sosfiltfilt in FilterRecording, time-split
median/mean in CommonReferenceRecording.  Default n_workers=1 preserves
existing behavior.

Per-caller-thread inner pools
-----------------------------
Each outer thread that calls ``get_traces()`` on a parallel-enabled segment
gets its own inner ThreadPoolExecutor, stored in a ``WeakKeyDictionary``
keyed by the calling ``Thread`` object.  Rationale:

* Avoids the shared-pool queueing pathology that would occur if N outer
  workers (e.g., TimeSeriesChunkExecutor with n_jobs=N) all submitted
  into a single shared pool with fewer max_workers than outer callers.
  Under a shared pool, ``n_workers=2`` with ``n_jobs=24`` thrashed at
  3.36 s on the test pipeline; per-caller pools: 1.47 s.
* Keying by the Thread object (not thread-id integer) avoids the
  thread-id-reuse hazard: thread IDs can be reused after a thread dies,
  which would cause a new thread to silently inherit a dead thread's
  pool.
* WeakKeyDictionary + weakref.finalize ensures automatic shutdown of
  the inner pool when the calling thread is garbage-collected.  The
  finalizer calls ``pool.shutdown(wait=False)`` to avoid blocking the
  finalizer thread; in-flight tasks would be cancelled, but the owning
  thread submits+joins synchronously, so none exist when it exits.

When useful
-----------
* Direct ``get_traces()`` callers (interactive viewers, streaming
  consumers, mipmap-zarr tile builders) that don't use
  ``TimeSeriesChunkExecutor``.
* Default SI users who haven't tuned job_kwargs.
* RAM-constrained deployments that can't crank ``n_jobs`` to core count:
  on a 24-core host, ``n_jobs=6, n_workers=2`` gets within 8% of
  ``n_jobs=24, n_workers=1`` at ~1/4 the RAM.

Performance (1M × 384 float32 BP+CMR pipeline, 24-core host, thread engine)
---------------------------------------------------------------------------
  === Component-level (scipy/numpy only) ===
  sosfiltfilt serial → 8 threads:   7.80 s →  2.67 s (2.92x)
  np.median serial   → 16 threads:  3.51 s →  0.33 s (10.58x)

  === Per-stage end-to-end (rec.get_traces) ===
  Bandpass (5th-order, 300-6k Hz): 8.59 s →  3.20 s (2.69x)
  CMR median (global):             4.01 s →  0.81 s (4.95x)

  === CRE outer × inner Pareto, per-caller pools ===
  outer=24, inner=1 each:          1.54 s  (100% of peak)
  outer=24, inner=8 each:          1.42 s  (108% of peak; oversubscribed)
  outer=12, inner=1 each:          1.59 s  (97%, ~1/2 RAM of outer=24)
  outer=6,  inner=2 each:          1.75 s  (92%, ~1/4 RAM of outer=24)
  outer=4,  inner=6 each:          1.83 s  (87%, ~1/6 RAM with 24 threads)

Tests
-----
New ``test_parallel_pool_semantics.py`` verifies the per-caller-thread
contract: single caller reuses one pool; concurrent callers get distinct
pools.  Existing bandpass + CMR tests still pass.

Independent of the companion FIR phase-shift PR (perf/phase-shift-fir);
the two can land in either order.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

The per-caller-thread pool dict on FilterRecordingSegment and
CommonReferenceRecordingSegment is keyed by Thread object via a
WeakKeyDictionary.  Across os.fork(), Python re-uses the calling thread's
identity in the child, so a child's first lookup returns the parent's
ThreadPoolExecutor — whose worker OS threads do not exist in the child.
The child's first submit() then blocks indefinitely.

Reproducer: parent calls get_traces() (lazily creating the pool), then
runs save() / write_binary_recording() with n_jobs > 1 and the default
fork start method on Linux.  Child workers hang in Sl state with 0% CPU.

Fix: stash os.getpid() alongside the pool dict.  In _get_pool, if the
current pid differs, rebuild the dict and lock from scratch before
proceeding.  Pickling (mp_context="spawn"/"forkserver") goes through
__reduce__ → __init__ and gets fresh state already, so this guard is
specifically for the fork copy-of-memory path.

Adds a regression test that pre-warms the pool, forks via mp.fork
context, and asserts get_traces() in the child completes within 30 s.
Without the guard the test deadlocks; with it, it passes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

The previous dispatch had each parallel worker return ``(c0, c1, block)``
tuples; the calling thread then allocated the output array and copied each
block into place.  That post-collection allocate-and-copy is wasted work
since the channel/time slices are non-overlapping — workers can write
directly into a pre-allocated output.

Measured on a (30000, 384) float32 chunk with sosfiltfilt and
n_workers=5:

  pattern                              wall (ms)   speedup
  E. sequential                          173.89      1.00×
  A. submit + collect + alloc + copy      75.66      2.30×   (current)
  B. pre-alloc, write in place            60.51      2.87×   (this PR)
  C. pool.map, write in place             63.55      2.74×
  D. manual threading.Thread              64.76      2.69×

So we save ~15 ms wall per `_apply_sos` call (likewise for
`_parallel_reduce_axis1`) by dropping the redundant copy.  Ideal 5×
scaling would be 34.78 ms; the remaining gap to ideal is the GIL-held
Python wrapper inside scipy's sosfiltfilt — pattern doesn't matter there
(B/C/D are all within noise), so we keep the simpler submit/result form.

Same pattern applied to common_reference._parallel_reduce_axis1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Switch from "one chunk per worker" to "many small chunks dispatched FIFO
to a fixed-size pool", sized so the per-chunk input fits L2 (~1.5 MB) and
N_workers active chunks fit shared L3.  All workers tend to cluster in
the same time region of the input at any moment, so shared L3 absorbs
the data once instead of N_workers independent streams competing for
DRAM.

Empirical wins on a 24-core x86_64 host with n_workers=16, fp32:

  T (samples)  OLD (1 chunk/wkr)  NEW       speedup
  30_000        44.6 ms            24.8 ms  1.80x
  524_288      397.8 ms           358.6 ms  1.11x

The 1.80x at T=30k (SI default chunk_duration='1s' at fs=30 kHz) is the
bigger win: the OLD min_block=8192 only used T // 8192 = 3 effective
workers at that T, leaving 13 idle.  The NEW scheme dispatches enough
chunks (~30 at T=30k) for all 16 workers to do useful work.

The 1.11x at T=524k is smaller because the OLD code already used all
workers there; the NEW scheme just shifts the cache pattern.  Direct
numpy bench (no SI plumbing) shows ~1.4x at this T; SI's get_traces
overhead dilutes that to 1.11x end-to-end.

Diminishing returns past ~16 chunks/worker — dispatch overhead starts
to compete with the cache win.  The block-size formula caps total
chunks at 64 * n_workers and floors the block at 256 rows.

No new feature; same n_workers kwarg, same correctness invariants.
Existing 12 CMR + parallel-pool tests pass unchanged.
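Putting the last two commits together, the dispatch for the CMR reduction looks roughly like the sketch below: blocks sized to fit L2, floored at 256 rows, count capped at 64 × n_workers, with workers writing into a pre-allocated output (pattern B above). Names are illustrative, not the PR's actual code:

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def parallel_median_reference(traces: np.ndarray, pool: ThreadPoolExecutor,
                              n_workers: int, l2_bytes: int = 1_500_000) -> np.ndarray:
    """Global-median CMR, time-split into many small FIFO-dispatched blocks."""
    n_samples, n_channels = traces.shape
    row_bytes = n_channels * traces.itemsize
    block = max(l2_bytes // row_bytes,              # one block's input fits L2
                256,                                # floor: 256 rows per block
                -(-n_samples // (64 * n_workers)))  # cap: <= 64*n_workers blocks
    out = np.empty_like(traces)

    def work(t0: int, t1: int) -> None:
        seg = traces[t0:t1]
        # Write directly into the pre-allocated output; slices are disjoint,
        # so no post-collection copy is needed (pattern B).
        out[t0:t1] = seg - np.median(seg, axis=1, keepdims=True)

    futures = [pool.submit(work, t0, min(t0 + block, n_samples))
               for t0 in range(0, n_samples, block)]
    for f in futures:
        f.result()  # propagate any worker exception
    return out
```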
galenlynch force-pushed the perf/parallel-filter-cmr branch from 7c8530c to 4fbd1a6 on May 8, 2026
