diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index 911698b..159ff4b 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -29,10 +29,32 @@ import torch import torch.distributed as dist import torch.nn as nn +import torch.nn.functional as F import torch.optim as optim from neuron_proofreader.utils import img_util, ml_util, util + +class FocalLoss(nn.Module): + """Binary focal loss for imbalanced classification. + + Downweights easy examples (high confidence, correct) so training + concentrates on the hard cases that drive false positives and false + negatives. Alpha upweights the positive class; gamma sharpens the + focus (gamma=0 reduces to standard BCE). + """ + + def __init__(self, alpha=0.25, gamma=2.0): + super().__init__() + self.alpha = alpha + self.gamma = gamma + + def forward(self, logits, targets): + bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none") + pt = torch.exp(-bce) + focal_weight = self.alpha * (1 - pt) ** self.gamma + return (focal_weight * bce).mean() + logger = logging.getLogger(__name__) _LOG_EVERY = 100 # batches between progress log lines @@ -83,6 +105,8 @@ def __init__( warmup_epochs=5, scheduler_type="cosine", pos_weight=None, + focal_gamma=None, + focal_alpha=0.25, save_val_logits=False, save_mistake_mips=False, on_best_model_saved=None, @@ -124,6 +148,7 @@ def __init__( # Instance attributes self.best_f1 = 0 self.best_val_loss = float("inf") + self.best_f1_at_95recall = 0.0 self.device = device self.log_dir = log_dir self.max_epochs = max_epochs @@ -138,11 +163,15 @@ def __init__( self.save_mistake_mips = save_mistake_mips self.on_best_model_saved = on_best_model_saved - if pos_weight is None: + if focal_gamma is not None: + self.criterion = FocalLoss(alpha=focal_alpha, gamma=focal_gamma) + print(f"Loss: FocalLoss(alpha={focal_alpha}, gamma={focal_gamma})") + elif pos_weight is None: self.criterion = nn.BCEWithLogitsLoss() else: pos_weight_tensor = torch.tensor([pos_weight], device=device) self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor) + print(f"Loss: BCEWithLogitsLoss(pos_weight={pos_weight})") self.model = model.to(device) self.optimizer = optim.AdamW( self._build_param_groups(self.model, lr, head_lr), @@ -256,16 +285,29 @@ def run(self, train_dataloader, val_dataloader): # Train-Validate train_stats = self.train_step(train_dataloader, epoch) val_stats = self.validate_step(val_dataloader, epoch) + new_best_loss = val_stats["loss"] < self.best_val_loss if new_best_loss: self.best_val_loss = val_stats["loss"] + + f1_95 = val_stats.get("f1_at_95recall", 0.0) + new_best_f1_95 = f1_95 > self.best_f1_at_95recall + if new_best_f1_95: + self.best_f1_at_95recall = f1_95 + + # Checkpoint: use F1@95recall once the model achieves it; fall back + # to val loss before that threshold is first reached. + if new_best_f1_95: + self.save_model(epoch, tag="best_f1_at_95recall") + if self.save_val_logits: + self._save_val_logits( + val_dataloader, self._last_val_y, self._last_val_hat_y, epoch + ) + elif new_best_loss and self.best_f1_at_95recall == 0.0: self.save_model(epoch, tag="best_loss") if self.save_val_logits: self._save_val_logits( - val_dataloader, - self._last_val_y, - self._last_val_hat_y, - epoch, + val_dataloader, self._last_val_y, self._last_val_hat_y, epoch ) # Log learning rate @@ -274,7 +316,12 @@ def run(self, train_dataloader, val_dataloader): self.writer.add_scalar("lr", current_lr, epoch) # Report results - print(f"\nEpoch {epoch}: " + ("New Best!" if new_best_loss else " ")) + is_new_best = new_best_f1_95 or (new_best_loss and self.best_f1_at_95recall == 0.0) + criterion_label = ( + f"F1@95R={f1_95:.4f}" if self.best_f1_at_95recall > 0.0 + else f"loss={val_stats['loss']:.4f}" + ) + print(f"\nEpoch {epoch}: " + (f"New Best! ({criterion_label})" if is_new_best else "")) self.report_stats(train_stats, is_train=True) self.report_stats(val_stats, is_train=False) @@ -289,8 +336,8 @@ def run(self, train_dataloader, val_dataloader): if new != old: print(f" LR reduced: group {i} {old:.2e} -> {new:.2e}") - # Early stopping check - if new_best_loss: + # Early stopping: track whichever criterion is active + if is_new_best: self.epochs_without_improvement = 0 else: self.epochs_without_improvement += 1 @@ -485,6 +532,27 @@ def forward_pass(self, x, y): return hat_y, loss # --- Helpers --- + @staticmethod + def _f1_at_recall_target(y, hat_y_logits, recall_target=0.95): + """Return the best F1 achievable at >= recall_target recall. + + Sweeps 200 probability thresholds and returns the maximum F1 among + those where recall >= recall_target. Returns 0.0 if the model never + achieves the target recall at any threshold. + """ + y_arr = np.array(y, dtype=int) + probs = 1.0 / (1.0 + np.exp(-np.array(hat_y_logits))) + thresholds = np.unique(np.percentile(probs, np.linspace(0, 100, 200))) + best_f1 = 0.0 + for t in thresholds: + preds = (probs >= t).astype(int) + r = recall_score(y_arr, preds, zero_division=0) + if r >= recall_target: + p = precision_score(y_arr, preds, zero_division=0) + f1 = 2 * p * r / max(p + r, 1e-8) + best_f1 = max(best_f1, f1) + return best_f1 + @staticmethod def compute_stats(y, hat_y): """ @@ -515,8 +583,10 @@ def compute_stats(y, hat_y): avg_recall = recall_score(y, hat_y, zero_division=np.nan) avg_f1 = 2 * avg_prec * avg_recall / max((avg_prec + avg_recall), 1e-8) avg_acc = accuracy_score(y, hat_y) + f1_at_95recall = Trainer._f1_at_recall_target(y, hat_y_arr) stats = { "f1": avg_f1, + "f1_at_95recall": f1_at_95recall, "precision": avg_prec, "recall": avg_recall, "accuracy": avg_acc, @@ -800,6 +870,8 @@ def __init__( warmup_epochs=5, scheduler_type="cosine", pos_weight=None, + focal_gamma=None, + focal_alpha=0.25, save_val_logits=False, save_mistake_mips=False ): @@ -876,6 +948,7 @@ def __init__( # Now initialize parent class attributes without creating directories self.best_f1 = 0 self.best_val_loss = float("inf") + self.best_f1_at_95recall = 0.0 self.device = device self.log_dir = log_dir self.max_epochs = max_epochs @@ -889,11 +962,17 @@ def __init__( self.save_val_logits = save_val_logits self.save_mistake_mips = save_mistake_mips - if pos_weight is None: + if focal_gamma is not None: + self.criterion = FocalLoss(alpha=focal_alpha, gamma=focal_gamma) + if self.rank == 0: + print(f"Loss: FocalLoss(alpha={focal_alpha}, gamma={focal_gamma})") + elif pos_weight is None: self.criterion = nn.BCEWithLogitsLoss() else: pos_weight_tensor = torch.tensor([pos_weight], device=device) self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor) + if self.rank == 0: + print(f"Loss: BCEWithLogitsLoss(pos_weight={pos_weight})") self.model = model.to(device) self.scaler = torch.cuda.amp.GradScaler(enabled=True) @@ -1092,15 +1171,32 @@ def run(self, train_dataloader, val_dataloader): new_best_loss = val_stats["loss"] < self.best_val_loss if new_best_loss: self.best_val_loss = val_stats["loss"] + + f1_95 = val_stats.get("f1_at_95recall", 0.0) + new_best_f1_95 = f1_95 > self.best_f1_at_95recall + if new_best_f1_95: + self.best_f1_at_95recall = f1_95 + + # Checkpoint: F1@95recall once achieved, val loss as fallback + if new_best_f1_95: + self.save_model(epoch, tag="best_f1_at_95recall") + if self.save_val_logits: + self._save_val_logits( + val_dataloader, self._last_val_y, self._last_val_hat_y, epoch + ) + elif new_best_loss and self.best_f1_at_95recall == 0.0: self.save_model(epoch, tag="best_loss") if self.save_val_logits: self._save_val_logits( - val_dataloader, - self._last_val_y, - self._last_val_hat_y, - epoch, + val_dataloader, self._last_val_y, self._last_val_hat_y, epoch ) - print(f"\nEpoch {epoch}: ", "New Best!" if new_best_loss else "") + + is_new_best = new_best_f1_95 or (new_best_loss and self.best_f1_at_95recall == 0.0) + criterion_label = ( + f"F1@95R={f1_95:.4f}" if self.best_f1_at_95recall > 0.0 + else f"loss={val_stats['loss']:.4f}" + ) + print(f"\nEpoch {epoch}: " + (f"New Best! ({criterion_label})" if is_new_best else "")) self.report_stats(train_stats, is_train=True) self.report_stats(val_stats, is_train=False) @@ -1117,8 +1213,8 @@ def run(self, train_dataloader, val_dataloader): print(f" LR reduced: group {i} {old:.2e} -> {new:.2e}") if rank == 0: - # Early stopping check - if new_best_loss: + # Early stopping: track whichever criterion is active + if is_new_best: self.epochs_without_improvement = 0 else: self.epochs_without_improvement += 1 diff --git a/src/neuron_proofreader/machine_learning/vision_models.py b/src/neuron_proofreader/machine_learning/vision_models.py index a20a91b..cd5393d 100644 --- a/src/neuron_proofreader/machine_learning/vision_models.py +++ b/src/neuron_proofreader/machine_learning/vision_models.py @@ -177,7 +177,12 @@ def __init__( self.encoder_dim = self.model.encoder_dim self.n_prefix_tokens = 1 + encoder.n_register_tokens self.grid_size = tuple(int(g) for g in encoder.grid_size) - self.pool_power = pool_power + # Learnable pooling power γ (log-parameterized so it stays positive). + # exp(log(pool_power)) = pool_power at init, so skeleton tokens start + # pool_power× heavier than segment tokens and background stays zero. + self.pool_log_power = nn.Parameter( + torch.tensor(float(pool_power)).log() + ) # Dual-stream classifier: [CLS, mask-pooled] → 1 self.classifier = nn.Sequential( @@ -203,8 +208,8 @@ def forward(self, x): weights = F.adaptive_max_pool3d(mask, self.grid_size) weights = weights.reshape(weights.shape[0], -1) # (B, n_patches) - # Power-scale: skeleton=1.0, segment=0.25, bg=0.0 (with power=2) - weights = weights ** self.pool_power + # Power-scale with learned γ; exp keeps γ strictly positive. + weights = weights ** self.pool_log_power.exp() # Normalize weights to sum to 1 weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py index 19c957a..eed011c 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py @@ -452,12 +452,21 @@ def get_random_negative_site(self): outcome = random.random() while True: # Sample node - if outcome < 0.4: + if outcome < 0.1: + # Near-merge hard negative: node 25–80 µm from a known merge + # site but confirmed not a merge itself. These are the hardest + # false positives the detector encounters at deployment. + node = self._sample_near_merge_negative(brain_id) + if node is None: + outcome = random.random() + continue + subgraph = self.graphs[brain_id].get_rooted_subgraph( + node, self.subgraph_radius + ) + return brain_id, subgraph, 0 + elif outcome < 0.44: # Any node node = util.sample_once(list(self.graphs[brain_id].nodes)) - #elif outcome < 0.5: - # # Node close to soma - # node = self.sample_node_nearby_soma(brain_id) elif outcome < 0.8: # Branching node branching_nodes = self.graphs[brain_id].get_branchings() @@ -491,6 +500,31 @@ def get_random_negative_site(self): if not self.is_nearby_merge_site(brain_id, node): return brain_id, subgraph, 0 + def _sample_near_merge_negative(self, brain_id, min_dist=25.0, max_dist=80.0, max_tries=20): + """ + Sample a fragment node that is near (but not at) a merge site. + + Nodes between min_dist and max_dist µm from the nearest GT merge site + are the hardest false positives at deployment. Adding ~10% of training + negatives from this zone makes the model discriminate them better. + + Returns None after max_tries failed attempts (e.g. sparse brain). + """ + if brain_id not in self.merge_site_kdtrees: + return None + kdtree = self.merge_site_kdtrees[brain_id] + graph = self.graphs[brain_id] + nodes = list(graph.nodes) + if not nodes: + return None + for _ in range(max_tries): + node = util.sample_once(nodes) + xyz = graph.node_xyz[node] + dist, _ = kdtree.query(xyz) + if min_dist <= dist <= max_dist: + return node + return None + def get_img_patch(self, brain_id, center): """ Extracts and normalizes a 3D image patch from the specified whole- @@ -537,8 +571,8 @@ def get_segment_mask(self, brain_id, center, subgraph): else: segment_mask = np.zeros(self.patch_shape) - # Annotate fragment - center = subgraph.get_voxel(0) + # Annotate fragment — use the passed center so translation augmentation + # shifts the skeleton overlay to match the shifted image read window. offset = img_util.get_offset(center, self.patch_shape) for node1, node2 in subgraph.edges: # Get local voxel coordinates @@ -663,7 +697,7 @@ class MergeSiteTrainDataset(MergeSiteDataset): A class for storing and retrieving training examples. """ - def __init__(self, base_dataset=None, idxs=None, negative_bias=0): + def __init__(self, base_dataset=None, idxs=None, negative_bias=0, max_translation=20): """ Instantiates a MergeSiteTrainDataset object. @@ -675,6 +709,10 @@ def __init__(self, base_dataset=None, idxs=None, negative_bias=0): Indices of examples to be kept in train dataset. negative_bias : float, optional Specifies percentage of additional negative examples to add. + max_translation : int, optional + Maximum voxel shift applied to the patch read center along each + axis during training. Shifts the merge site off-center to improve + robustness to misaligned inputs. Default is 20 voxels. """ # Create sub-dataset subset_dataset = base_dataset.subset(self.__class__, idxs) @@ -682,6 +720,7 @@ def __init__(self, base_dataset=None, idxs=None, negative_bias=0): # Instance attributes self.negative_bias = negative_bias + self.max_translation = max_translation self.transform = ImageTransforms() # --- Getters --- @@ -704,7 +743,25 @@ def __getitem__(self, idx): label : int 1 if the example is positive and 0 otherwise. """ - patches, subgraph, label = super().__getitem__(idx) + brain_id, subgraph, label = self.get_site(idx) + voxel = subgraph.get_voxel(0) + + # Random translation: shift the read window so the site appears + # off-center, training the model to be robust to misaligned inputs. + if self.max_translation > 0: + delta = np.random.randint(-self.max_translation, self.max_translation + 1, 3) + voxel = tuple(int(v + d) for v, d in zip(voxel, delta)) + + img_patch = self.get_img_patch(brain_id, voxel) + segment_mask = self.get_segment_mask(brain_id, voxel, subgraph) + + try: + patches = np.stack([img_patch, segment_mask], axis=0) + except ValueError: + img_patch = img_util.pad_to_shape(img_patch, self.patch_shape) + patches = np.stack([img_patch, segment_mask], axis=0) + + patches[0] = (patches[0] - patches[0].mean()) / (patches[0].std() + 1e-8) patches = self.transform(patches) return patches, subgraph, label diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 26addf5..88c0fb8 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -24,6 +24,9 @@ from neuron_proofreader.machine_learning.point_cloud_models import ( subgraph_to_point_cloud, ) +from neuron_proofreader.merge_proofreading.sparse_sampling import ( + compute_interesting_nodes, +) from neuron_proofreader.utils import ( geometry_util, img_util, @@ -556,7 +559,13 @@ def estimate_iterations(self): return int(length / self.step_size) -class SparseGraphDataset(GraphDataset): +class SparseGraphDataset(DenseGraphDataset): + """ + Inference dataset that samples only nodes near branch points or near + other axons, skipping long isolated axon segments. Inherits the + image-prefetch and per-node feature extraction from DenseGraphDataset + and overrides only the node-selection logic. + """ def __init__( self, @@ -564,72 +573,93 @@ def __init__( img_path, patch_shape, batch_size=16, + branch_radius=25.0, + brightness_clip=300, is_multimodal=False, min_search_size=0, prefetch=128, + proximity_radius=15.0, segmentation_path=None, + step_size=10, subgraph_radius=100, - use_new_mask=False + use_new_mask=False, ): - # Call parent class super().__init__( graph, img_path, patch_shape, batch_size=batch_size, + brightness_clip=brightness_clip, is_multimodal=is_multimodal, min_search_size=min_search_size, prefetch=prefetch, segmentation_path=segmentation_path, + step_size=step_size, subgraph_radius=subgraph_radius, - use_new_mask=use_new_mask + use_new_mask=use_new_mask, ) # Instance attributes - self.search_mode = "branching_points" - - def _generate_batches_from_component(self): - pass + self.search_mode = "biased_sparse" + self.branch_radius = branch_radius + self.proximity_radius = proximity_radius + self._interesting_nodes = compute_interesting_nodes( + graph, + branch_radius=branch_radius, + proximity_radius=proximity_radius, + ) def _generate_batch_nodes(self, root): + """ + Iterates the connected component containing "root" via DFS at the + same "step_size" cadence as DenseGraphDataset, but emits only nodes + pre-selected by "compute_interesting_nodes". Long, isolated axon + segments contribute zero samples. + """ nodes = list() - patch_centers = list() for i, j in nx.dfs_edges(self.graph, source=root): - # Check if starting new batch self.distance_traversed += self.graph.dist(i, j) - if len(patch_centers) == 0 and self.graph.degree[i] > 2: - root = i - nodes.append(i) - patch_centers.append(self.graph.get_voxel(i)) - # Check whether to yield batch - is_node_far = self.graph.dist(root, j) > 256 - is_batch_full = len(patch_centers) == self.batch_size - if is_node_far or is_batch_full: - # Yield batch metadata - patch_centers = np.array(patch_centers, dtype=int) - nodes = np.array(nodes, dtype=int) - yield nodes, patch_centers + # Open a batch on the first interesting node we reach + if len(nodes) == 0: + if i in self._interesting_nodes and self.is_node_valid(i): + root = i + last_node = i + nodes.append(i) + else: + continue - # Reset batch metadata + # Yield when batch is full or has spread too far for prefetch + is_node_far = self.graph.dist(root, j) > 512 + is_batch_full = len(nodes) == self.batch_size + if is_node_far or is_batch_full: + yield np.array(nodes, dtype=int) nodes = list() - patch_centers = list() - # Visit j - if self.graph.degree[j] > 2: + # Visit j: same step_size cadence as Dense, gated on the + # interesting set so only branch- and proximity-region nodes + # are emitted. + is_next = self.graph.dist(last_node, j) >= self.step_size - 2 + is_branching = self.graph.degree[j] >= 3 + if ( + (is_next or is_branching) + and j in self._interesting_nodes + and self.is_node_valid(j) + ): + last_node = j nodes.append(j) - patch_centers.append(self.graph.get_voxel(j)) - if len(patch_centers) == 1: + if len(nodes) == 1: root = j + if nodes: + yield np.array(nodes, dtype=int) + # --- Helpers --- def estimate_iterations(self): """ - Estimates the number of iterations required to search graph. - - Returns - ------- - int - Estimated number of iterations required to search graph. + Estimates the number of iterations required to search graph: the + size of the interesting set, divided by the step_size cadence used + within those regions. """ - return len(self.graph.get_branchings()) + step = max(1, self.step_size) + return max(1, len(self._interesting_nodes) // step) diff --git a/src/neuron_proofreader/merge_proofreading/sparse_sampling.py b/src/neuron_proofreader/merge_proofreading/sparse_sampling.py new file mode 100644 index 0000000..7a35198 --- /dev/null +++ b/src/neuron_proofreader/merge_proofreading/sparse_sampling.py @@ -0,0 +1,62 @@ +""" +Helpers for biased sparse inference-time sampling: pre-selects skeleton +nodes that sit near a branch point or near another axon, so that merge +detection can skip long, isolated axon segments. + +Kept in its own module (only numpy + a networkx graph object) so it can be +exercised by unit tests without pulling in the full merge_inference image +and ML stack. +""" + +import numpy as np + + +def compute_interesting_nodes(graph, branch_radius=25.0, proximity_radius=15.0): + """ + Selects nodes worth running merge detection on: the union of (a) nodes + within "branch_radius" graph-distance of a branching node, and (b) nodes + within "proximity_radius" Euclidean distance of a node belonging to a + different connected component. + + Parameters + ---------- + graph : SkeletonGraph + Skeleton graph with "node_xyz", "node_component_id", and "kdtree" + populated. Must expose "get_branchings()", "neighbors(i)", + "dist(i, j)", and "set_kdtree()". + branch_radius : float, optional + Graph-distance window (microns) around branching nodes. Default 25. + proximity_radius : float, optional + Euclidean threshold (microns) for nodes treated as "near another + axon". Default 15 (matches "geometry_util.is_double_merge"). + + Returns + ------- + Set[int] + Node IDs to sample at inference time. + """ + # Branch-region nodes: bounded DFS from every branching node + branch_set = set(graph.get_branchings()) + queue = [(i, 0.0) for i in branch_set] + while queue: + i, dist_i = queue.pop() + for j in graph.neighbors(i): + dist_j = dist_i + graph.dist(i, j) + if j not in branch_set and dist_j < branch_radius: + branch_set.add(j) + queue.append((j, dist_j)) + + # Inter-component proximity nodes + if graph.kdtree is None: + graph.set_kdtree() + proximity_set = set() + for i in graph.nodes: + idxs = np.array( + graph.kdtree.query_ball_point(graph.node_xyz[i], proximity_radius) + ) + if idxs.size and np.any( + graph.node_component_id[idxs] != graph.node_component_id[i] + ): + proximity_set.add(i) + + return branch_set | proximity_set