Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 65 additions & 35 deletions src/neuron_proofreader/merge_proofreading/merge_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from neuron_proofreader.machine_learning.point_cloud_models import (
subgraph_to_point_cloud,
)
from neuron_proofreader.merge_proofreading.sparse_sampling import (
compute_interesting_nodes,
)
from neuron_proofreader.utils import (
geometry_util,
img_util,
Expand Down Expand Up @@ -556,80 +559,107 @@ def estimate_iterations(self):
return int(length / self.step_size)


class SparseGraphDataset(DenseGraphDataset):
    """
    Inference dataset that samples only nodes near branch points or near
    other axons, skipping long isolated axon segments. Inherits the
    image-prefetch and per-node feature extraction from DenseGraphDataset
    and overrides only the node-selection logic.
    """

    def __init__(
        self,
        graph,
        img_path,
        patch_shape,
        batch_size=16,
        branch_radius=25.0,
        brightness_clip=300,
        is_multimodal=False,
        min_search_size=0,
        prefetch=128,
        proximity_radius=15.0,
        segmentation_path=None,
        step_size=10,
        subgraph_radius=100,
        use_new_mask=False,
    ):
        """
        Instantiates a SparseGraphDataset.

        Parameters
        ----------
        graph : SkeletonGraph
            Graph to run merge detection over.
        img_path : str
            Path to the image volume read by the parent class.
        patch_shape : tuple
            Shape of the image patch extracted around each node.
        batch_size : int, optional
            Number of nodes per yielded batch. Default is 16.
        branch_radius : float, optional
            Graph-distance window around branching nodes considered
            interesting. Default is 25.0.
        proximity_radius : float, optional
            Euclidean threshold for nodes treated as near another axon.
            Default is 15.0.
        step_size : int, optional
            Sampling cadence (distance between emitted nodes). Default 10.

        Remaining keyword arguments are forwarded unchanged to
        DenseGraphDataset.
        """
        # Parent sets up image prefetching and per-node feature extraction
        super().__init__(
            graph,
            img_path,
            patch_shape,
            batch_size=batch_size,
            brightness_clip=brightness_clip,
            is_multimodal=is_multimodal,
            min_search_size=min_search_size,
            prefetch=prefetch,
            segmentation_path=segmentation_path,
            step_size=step_size,
            subgraph_radius=subgraph_radius,
            use_new_mask=use_new_mask,
        )

        # Instance attributes
        self.search_mode = "biased_sparse"
        self.branch_radius = branch_radius
        self.proximity_radius = proximity_radius
        # Precompute the node subset worth sampling (branch regions and
        # inter-component proximity regions)
        self._interesting_nodes = compute_interesting_nodes(
            graph,
            branch_radius=branch_radius,
            proximity_radius=proximity_radius,
        )

    def _generate_batch_nodes(self, root):
        """
        Iterates the connected component containing "root" via DFS at the
        same "step_size" cadence as DenseGraphDataset, but emits only nodes
        pre-selected by "compute_interesting_nodes". Long, isolated axon
        segments contribute zero samples.

        Parameters
        ----------
        root : int
            Node whose connected component is traversed.

        Yields
        ------
        numpy.ndarray
            Node IDs forming one batch.
        """
        nodes = list()
        for i, j in nx.dfs_edges(self.graph, source=root):
            self.distance_traversed += self.graph.dist(i, j)

            # Open a batch on the first interesting node we reach
            if len(nodes) == 0:
                if i in self._interesting_nodes and self.is_node_valid(i):
                    root = i
                    last_node = i
                    nodes.append(i)
                else:
                    continue

            # Yield when batch is full or has spread too far for prefetch
            is_node_far = self.graph.dist(root, j) > 512
            is_batch_full = len(nodes) == self.batch_size
            if is_node_far or is_batch_full:
                yield np.array(nodes, dtype=int)
                nodes = list()

            # Visit j: same step_size cadence as Dense, gated on the
            # interesting set so only branch- and proximity-region nodes
            # are emitted.
            is_next = self.graph.dist(last_node, j) >= self.step_size - 2
            is_branching = self.graph.degree[j] >= 3
            if (
                (is_next or is_branching)
                and j in self._interesting_nodes
                and self.is_node_valid(j)
            ):
                last_node = j
                nodes.append(j)
                if len(nodes) == 1:
                    root = j

        # Flush the final partial batch
        if nodes:
            yield np.array(nodes, dtype=int)

    # --- Helpers ---
    def estimate_iterations(self):
        """
        Estimates the number of iterations required to search graph: the
        size of the interesting set, divided by the step_size cadence used
        within those regions.

        Returns
        -------
        int
            Estimated number of iterations required to search graph.
        """
        step = max(1, self.step_size)
        return max(1, len(self._interesting_nodes) // step)
62 changes: 62 additions & 0 deletions src/neuron_proofreader/merge_proofreading/sparse_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Helpers for biased sparse inference-time sampling: pre-selects skeleton
nodes that sit near a branch point or near another axon, so that merge
detection can skip long, isolated axon segments.

Kept in its own module (only numpy + a networkx graph object) so it can be
exercised by unit tests without pulling in the full merge_inference image
and ML stack.
"""

import heapq

import numpy as np


def compute_interesting_nodes(graph, branch_radius=25.0, proximity_radius=15.0):
    """
    Selects nodes worth running merge detection on: the union of (a) nodes
    within "branch_radius" graph-distance of a branching node, and (b) nodes
    within "proximity_radius" Euclidean distance of a node belonging to a
    different connected component.

    Parameters
    ----------
    graph : SkeletonGraph
        Skeleton graph with "node_xyz", "node_component_id", and "kdtree"
        populated. Must expose "get_branchings()", "neighbors(i)",
        "dist(i, j)", and "set_kdtree()".
    branch_radius : float, optional
        Graph-distance window (microns) around branching nodes. Default 25.
    proximity_radius : float, optional
        Euclidean threshold (microns) for nodes treated as "near another
        axon". Default 15 (matches "geometry_util.is_double_merge").

    Returns
    -------
    Set[int]
        Node IDs to sample at inference time.
    """
    # Branch-region nodes: multi-source Dijkstra from every branching node.
    # A plain visited-set walk under-explores when a node is first reached
    # via a longer path (e.g. around a cycle, or from a farther one of two
    # nearby branchings): once marked visited, its shorter path -- and the
    # still-in-range neighbors it would unlock -- is never expanded. Track
    # the best known distance per node and relax instead.
    best_dist = {i: 0.0 for i in graph.get_branchings()}
    heap = [(0.0, i) for i in best_dist]
    heapq.heapify(heap)
    while heap:
        dist_i, i = heapq.heappop(heap)
        if dist_i > best_dist[i]:
            continue  # stale heap entry
        for j in graph.neighbors(i):
            dist_j = dist_i + graph.dist(i, j)
            if dist_j < branch_radius and dist_j < best_dist.get(j, np.inf):
                best_dist[j] = dist_j
                heapq.heappush(heap, (dist_j, j))
    branch_set = set(best_dist)

    # Inter-component proximity nodes
    if graph.kdtree is None:
        graph.set_kdtree()
    proximity_set = set()
    for i in graph.nodes:
        hits = graph.kdtree.query_ball_point(
            graph.node_xyz[i], proximity_radius
        )
        if hits:
            idxs = np.asarray(hits, dtype=int)
            if np.any(
                graph.node_component_id[idxs] != graph.node_component_id[i]
            ):
                proximity_set.add(i)

    return branch_set | proximity_set