Training improvements - sparse #694

Open
schaugf wants to merge 3 commits into training-improvements from training-improvements-sparse


schaugf (Collaborator) commented May 14, 2026

  • Sparse inference support: Adds --sparse flag to test_inference.py,
    run_test_inference.sh, and beaker/submit_test_inference.sh. When set, the pipeline uses
    SparseGraphDataset (biased branch/proximity sampling) instead of DenseGraphDataset.
    Tunable via --branch_radius (default 25 μm) and --proximity_radius (default 15 μm). The
    Beaker submit script auto-sets --variant=sparse so results land in mergedetection_sparse/
    on S3 and don't clobber the dense baseline.
  • Normalization mixin refactor: Extracts the per-patch normalization logic into a
    _PerPatchNormMixin shared by both PatchNormalizedDenseGraphDataset and the new
    PatchNormalizedSparseGraphDataset. Eliminates the duplicate constructor kwargs pattern via
    dataset_kwargs.
  • GCS credential fix: Beaker jobs launched via command: bypass ENTRYPOINT, so
    entrypoint.sh never ran. run_test_inference.sh now writes the GCS token directly. Both
    scripts emit explicit status messages so a missing secret is visible in logs rather than
    surfacing later as a cryptic gcs_token.json not found error.
  • FP site cap removed: Removed the 200-sample cap on false-positive sites in
    get_site_predictions; all FP sites are now scored.
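
The normalization mixin refactor described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the stand-in base classes, the `get_patch` hook, the `eps` parameter, and the zero-mean/unit-variance formula are all assumptions made for the example. Only the class names and the idea of forwarding `dataset_kwargs` come from the description.

```python
import statistics


class DenseGraphDataset:
    """Hypothetical stand-in for the real dense dataset."""

    def __init__(self, patches, step_size=1):
        self.patches = patches
        self.step_size = step_size

    def get_patch(self, i):
        # Hypothetical hook; the real class reads image patches.
        return list(self.patches[i])


class SparseGraphDataset(DenseGraphDataset):
    """Hypothetical stand-in that adds the two sparse-sampling radii."""

    def __init__(self, patches, branch_radius=25.0, proximity_radius=15.0, **kwargs):
        super().__init__(patches, **kwargs)
        self.branch_radius = branch_radius
        self.proximity_radius = proximity_radius


class _PerPatchNormMixin:
    """Normalizes each patch on its own and forwards every other kwarg."""

    def __init__(self, *args, eps=1e-6, **dataset_kwargs):
        # dataset_kwargs goes to whichever dataset class follows in the MRO,
        # so the mixin never has to repeat the base constructor signature.
        super().__init__(*args, **dataset_kwargs)
        self.eps = eps

    def get_patch(self, i):
        patch = super().get_patch(i)
        mean = statistics.fmean(patch)
        std = statistics.pstdev(patch)
        # Assumed normalization: zero mean, unit variance per patch.
        return [(v - mean) / (std + self.eps) for v in patch]


class PatchNormalizedDenseGraphDataset(_PerPatchNormMixin, DenseGraphDataset):
    pass


class PatchNormalizedSparseGraphDataset(_PerPatchNormMixin, SparseGraphDataset):
    pass
```

Under these assumptions, `PatchNormalizedSparseGraphDataset(patches, branch_radius=10.0, step_size=2)` passes both keyword arguments through the mixin unchanged, which is the duplication the refactor aims to remove.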

schaugf and others added 3 commits May 12, 2026 14:05
Replaces the placeholder SparseGraphDataset with the prototype from
feat/sparse-inference-sampling (5b53700), ported to the training-improvements
SkeletonGraph API (get_branchings vs. branching_nodes).

SparseGraphDataset now inherits from DenseGraphDataset and overrides only
_generate_batch_nodes / estimate_iterations. Node selection is gated by
sparse_sampling.compute_interesting_nodes, which picks the union of:
  - nodes within branch_radius graph-distance of any branching node
  - nodes within proximity_radius euclidean distance of a node in another
    connected component

Everything else in merge_inference.py is unchanged from training-improvements.
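
The node-selection rule in the commit message above can be illustrated with a small, dependency-free sketch. It is not the code in sparse_sampling.py: the `adj`/`xyz` dictionary inputs are hypothetical, "branching node" is assumed to mean degree 3 or more, graph distance is assumed to be path length in microns, and the proximity test is brute force, where a real implementation would more likely use a spatial index.

```python
import heapq
import math


def compute_interesting_nodes(adj, xyz, branch_radius=25.0, proximity_radius=15.0):
    """Sketch: adj maps node -> neighbor list, xyz maps node -> (x, y, z)."""
    # 1) Multi-source Dijkstra outward from every branching node,
    #    stopping once the accumulated path length exceeds branch_radius.
    dist = {i: 0.0 for i, nbrs in adj.items() if len(nbrs) >= 3}
    heap = [(0.0, i) for i in dist]
    heapq.heapify(heap)
    while heap:
        d, i = heapq.heappop(heap)
        if d > dist.get(i, math.inf):
            continue  # stale heap entry
        for j in adj[i]:
            dist_j = d + math.dist(xyz[i], xyz[j])
            if dist_j <= branch_radius and dist_j < dist.get(j, math.inf):
                dist[j] = dist_j
                heapq.heappush(heap, (dist_j, j))
    near_branch = set(dist)

    # 2) Label connected components with a depth-first traversal.
    comp = {}
    for root in adj:
        if root in comp:
            continue
        comp[root] = root
        stack = [root]
        while stack:
            i = stack.pop()
            for j in adj[i]:
                if j not in comp:
                    comp[j] = root
                    stack.append(j)

    # 3) Keep nodes that lie within proximity_radius (Euclidean) of any
    #    node belonging to a different component. O(n^2) for clarity only.
    near_other = {
        i
        for i in adj
        for j in adj
        if comp[i] != comp[j] and math.dist(xyz[i], xyz[j]) <= proximity_radius
    }
    return near_branch | near_other
```

For a toy skeleton with one branch point and one nearby isolated node, the function returns the branch neighborhood plus the two nodes that face each other across components, and skips nodes far from both.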

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Copilot AI left a comment


Pull request overview

This PR adds the core sparse sampling path for merge inference, enabling inference datasets to focus on branch-adjacent and inter-component-proximity nodes instead of dense traversal.

Changes:

  • Adds compute_interesting_nodes() for branch/proximity node selection.
  • Refactors SparseGraphDataset to inherit dense batching/prefetch behavior while overriding node selection.
  • Adds sparse sampling parameters for branch and proximity radii.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| src/neuron_proofreader/merge_proofreading/sparse_sampling.py | Adds sparse node-selection helper based on branch distance and component proximity. |
| src/neuron_proofreader/merge_proofreading/merge_inference.py | Imports sparse sampling and implements SparseGraphDataset using dense inference batching with sparse node filtering. |


queue.append((j, dist_j))

# Inter-component proximity nodes
if graph.kdtree is None:
Comment on lines +562 to +568
class SparseGraphDataset(DenseGraphDataset):
    """
    Inference dataset that samples only nodes near branch points or near
    other axons, skipping long isolated axon segments. Inherits the
    image-prefetch and per-node feature extraction from DenseGraphDataset
    and overrides only the node-selection logic.
    """
Comment on lines +603 to +605
self.search_mode = "biased_sparse"
self.branch_radius = branch_radius
self.proximity_radius = proximity_radius
Comment on lines +664 to +665
step = max(1, self.step_size)
return max(1, len(self._interesting_nodes) // step)
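
To make the clamping in the two lines above concrete, here is a standalone restatement with a few sample inputs. The free function and its argument names are hypothetical; the arithmetic mirrors the excerpt.

```python
def estimate_iterations(n_interesting, step_size):
    """Standalone restatement of the excerpt's arithmetic (hypothetical helper)."""
    step = max(1, step_size)  # guards against a zero or negative step
    return max(1, n_interesting // step)  # never reports fewer than 1 iteration


# Sample inputs: typical case, step larger than the node count, empty node set.
samples = [(100, 8), (5, 8), (0, 8), (10, 0)]
estimates = [estimate_iterations(n, s) for n, s in samples]
```

Note that the lower clamp means an empty set of interesting nodes still yields an estimate of one iteration, so callers that rely on this value for progress reporting may want to handle the empty case separately.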