diff --git a/README.md b/README.md
index ce82b35..c0867cf 100644
--- a/README.md
+++ b/README.md
@@ -2,85 +2,212 @@
 [![Tests](https://github.com/YosefLab/csde/actions/workflows/test.yml/badge.svg)](https://github.com/YosefLab/csde/actions/workflows/test.yml)
-`csde` (Corrected Spatial Differential Expression) is a Python package designed to **identify differentially expressed (DE) genes between spatially-resolved cell populations** (e.g., T-cells inside vs. outside a tumor).
+Automated pipelines for spatial transcriptomics produce cell quantifications (cell-by-gene expression matrices and label assignments) that contain systematic errors, e.g., due to mis-segmentation of cell boundaries.
+These errors can propagate into downstream analyses of differential expression, leading to false discoveries or missed signals.
-Standard analysis relies on cell population assignments (e.g., "infiltrating" vs. "non-infiltrating") obtained automatically from clustering/ML that are often prone to errors. `csde` corrects for these inaccuracies by leveraging a small subset of validated "ground-truth" data, providing rigorous statistical guarantees for spatially-resolved DE analyses.
+CSDE corrects for these errors by combining the large automated dataset with a small set of manually validated cells, using prediction-powered inference to recover unbiased estimates with valid confidence intervals.
-Refer to the preprint and the [project repository](https://github.com/YosefLab/csde) for more details.
+The current codebase focuses on the comparison of a given cell type across two spatial regions.
+It allows users to:
+1. export per-cell annotation panels for a small subset of cells (e.g. 600)
+2. manually validate the segmentation and type assignment for these cells
+3. run the CSDE model to get corrected DE estimates for all genes
+
+Refer to the [preprint](https://www.biorxiv.org/content/10.64898/2026.01.15.699786v1) for details on the method.
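As a rough illustration of the prediction-powered idea described in the README text above, the sketch below estimates a simple mean from many automated values plus a small validated subset. It is not the CSDE model: the function name `ppi_mean` and the parameter `lam` are hypothetical, and the formula is the standard prediction-powered mean estimate with a tuning weight, written with the standard library only.

```python
from statistics import fmean


def ppi_mean(y_gt, yhat_gt, yhat_unlabeled, lam=1.0):
    """Prediction-powered estimate of a mean (illustrative sketch only).

    y_gt           : validated values on the small labelled set
    yhat_gt        : automated values on the same labelled cells
    yhat_unlabeled : automated values on the large unlabelled set
    lam            : weight given to the automated values (0 ignores them)
    """
    # Average of the automated values on the large set.
    automated = fmean(yhat_unlabeled)
    # Average error of the automated values, measured on the validated set.
    rectifier = fmean(y - lam * yh for y, yh in zip(y_gt, yhat_gt))
    return lam * automated + rectifier


# Toy data: the automated values are biased upwards by 0.5.
y_gt = [1.0, 2.0, 3.0]
yhat_gt = [1.5, 2.5, 3.5]
yhat_unlabeled = [2.5, 2.5, 2.5, 2.5]
estimate = ppi_mean(y_gt, yhat_gt, yhat_unlabeled)
```

With `lam=0.0` the estimate reduces to the plain mean of the validated values, which is one way to see that the automated data can only help when they are informative.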
+ +### Input requirements + +The workflow takes a [SpatialData](https://spatialdata.scverse.org/) zarr as input. Its `"table"` AnnData must contain: + +- **raw expression counts** in `.X` or a named layer +- **the following `obs` columns:** + +| obs column | content | +| --- | --- | +| `cell_type` (configurable) | cell-type label for each cell | +| `spatial_group` (configurable) | binary spatial region label (e.g. `0` = outside tumour, `1` = inside tumour) | +| `center_x`, `center_y` | cell centroid in microns | + +The zarr must also expose at least one **fluorescence image channel** (e.g. `"DAPI"`, `"Cellbound2"`) used to render the per-cell annotation panels. ## Installation ```bash pip install csde +pip install "csde[cuda12]" # GPU (CUDA 12) +pip install "csde[annotate]" # annotation UI (Step 2, requires streamlit) +pip install "csde[cuda12,annotate]" # both ``` -By default, this installs JAX with CPU support. To enable GPU support (CUDA), install with the appropriate extra (e.g., for CUDA 12): +## Workflow overview + +``` +SpatialData zarr + │ + ▼ +1. Export annotation panels ←─ scripts/export.py + (importance-sampled cells, + one image per cell) + │ + ▼ +2. Manual validation ←─ scripts/annotate.py + (annotator marks each cell + as correctly / incorrectly labelled) + │ + ▼ +3. Run CSDE ←─ scripts/differential_expression.py + (corrected DE estimates) +``` + +--- + +## Step 1 — Export annotation panels (`scripts/export.py`) + +Before running the statistical model, a small subset of cells must be manually validated. `csde` provides tooling to generate the per-cell images needed for that step. 
+
 ```bash
-pip install "csde[cuda12]"
+python scripts/export.py \
+--sdata /path/to/region.zarr \
+--out /path/to/annotation_dir \
+--cell-type-key cell_type \
+--cell-type-of-interest macrophages \
+--target-proportion 0.4 \
+--gene-colors scripts/gene_colors_file.json \
+--image-channel Cellbound2 \
+--n-cells 600
 ```
-## Data Requirements
+`--target-proportion` controls the fraction of cells of interest in the subsample. Cells of interest are upweighted accordingly (importance sampling); the unnormalized weight for each sampled cell is stored in `metadata.csv` for downstream use.
-`csde` requires two `AnnData` objects containing gene expression counts. Typically, these are obtained by splitting your full dataset into two groups:
+The script writes:
-### 1. `adata_pred`: The dataset to analyze
-This object contains the bulk of your cells (e.g., the majority of the tissue) where only standard (predicted) cell population assignments are available.
+```
+/path/to/annotation_dir/
+├── images/
+│   ├── cell_<cell_id>.png   # one panel per cell
+│   └── ...
+├── config.json              # all export arguments (read by annotate.py)
+├── metadata.csv             # cell_id, cell_type, image_path, sampling_weight, center_x, center_y
+└── annotations.json         # {cell_id: true/false} — written by annotate.py
+```
-**Requirements:**
-* A column in `.obs` (e.g., `"cell_population"`) containing cell population labels (e.g., "T cell (infiltrating)" vs. "T cell (non-infiltrating)"). These labels can be derived from heuristics (e.g., distance to tumor) and/or computational classifiers.
+Each panel contains:
+- **Left** — fluorescence image crop + cell boundaries + transcript dots for genes listed in `gene_colors`
+- **Right** — top expressed genes (bar chart); genes in `gene_colors` use their assigned colour, others are grey
-### 2. `adata_gt`: The correction set
-This object contains a small subset of randomly sampled cells whose cell population assignments have been **validated** to serve as a ground truth.
This set allows `csde` to estimate the error rate of the standard predictions. +### Gene color file -**Requirements:** -* **Prediction column:** The same column name as in `adata_pred` (e.g., `"cell_population"`), containing the automated labels. -* **Validation column:** A **boolean** column in `.obs` (e.g., `"is_correct"`) indicating if the automated label matches the validation ground truth (see [How to construct `adata_gt`?](#how-to-construct-adata_gt)). +A simple JSON mapping gene names to colours: -## Usage +```json +{ + "CD68": "#e41a1c", + "MRC1": "#377eb8", + "C1QA": "#4daf4a", + "FCGR3A": "#ff7f00" +} +``` -```python -from csde import run_csde +
+Python API -results = run_csde( - # `AnnData` datasets to analyze - adata_pred=adata_pred, - adata_gt=adata_gt, - - # Column containing the predicted labels (in BOTH datasets) - pred_cell_pop_key="cell_population", - - # The two populations to compare - cell_pop_a="T-cell (infiltrating)", # Reference group - cell_pop_b="T-cell (non-infiltrating)", # Target group - - # Boolean column in adata_gt verifying the prediction - gt_key="is_correct", - - # Optional: Use a specific layer for counts (default uses .X) - layer_name="counts" +```python +import json +import spatialdata as sd +from csde import export_cell_panels, subsample_cells, plot_top_genes + +sdata = sd.read_zarr("/path/to/region.zarr") +gene_colors = json.load(open("gene_colors.json")) + +metadata = export_cell_panels( + sdata=sdata, + annotation_dir="/path/to/annotation_dir", + cell_type_key="cell_type", + cell_type_of_interest="macrophages", + target_proportion=0.4, + gene_colors=gene_colors, + image_channel="Cellbound2", + n_cells=600, ) +``` +
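To make the role of `target_proportion` more concrete, here is a minimal standard-library sketch of sampling towards a target proportion and recording an inverse-probability weight per sampled cell. The helper `sample_with_target_proportion` is hypothetical, it samples with replacement for simplicity, and the weight definition is an assumption; the actual `subsample_cells` / `sampling_weight` logic in `csde` may differ.

```python
import random


def sample_with_target_proportion(is_coi, target_proportion, n_cells, seed=0):
    """Draw cell indices so that about `target_proportion` are cells of interest.

    Returns (indices, weights), where each weight is the unnormalized inverse
    of that cell's sampling probability. Assumes both groups are non-empty.
    """
    n_coi = sum(is_coi)
    n_other = len(is_coi) - n_coi
    # Per-cell sampling probability within each group; the two groups
    # receive total mass target_proportion and 1 - target_proportion.
    p_coi = target_proportion / n_coi
    p_other = (1.0 - target_proportion) / n_other
    probs = [p_coi if flag else p_other for flag in is_coi]

    rng = random.Random(seed)
    indices = rng.choices(range(len(is_coi)), weights=probs, k=n_cells)
    weights = [1.0 / probs[i] for i in indices]
    return indices, weights


# Toy population: 10% cells of interest, oversampled to about 40%.
is_coi = [True] * 100 + [False] * 900
indices, weights = sample_with_target_proportion(is_coi, 0.4, 600)
```

Oversampled cells end up with smaller weights here, so that weighted averages over the subsample still reflect the original population.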
-# Returns a DataFrame with log_fold_change, p_value, and adjusted p_value
-print(results.head())
+---
+
+## Step 2 — Manual validation (`scripts/annotate.py`)
+
+For each exported image, an annotator decides whether the automated cell-type label is correct. The decisions are saved to `annotations.json` as a `{cell_id: true/false}` mapping, which becomes the boolean `is_correct` column of `adata_gt` in Step 3.
+
+```bash
+streamlit run scripts/annotate.py -- --dir /path/to/annotation_dir
+```
+
+VS Code Remote forwards the Streamlit port automatically. Open the URL printed in the terminal, then use:
+
+- **`1`** — label as correct
+- **`2`** — label as incorrect
+
+Progress is saved after every keypress to `annotations.json`. Re-running the command resumes from the first cell that has not been annotated yet. You can also start annotating while `export.py` is still running — the UI picks up newly exported cells automatically.
+
+---
+
+## Step 3 — Differential expression (`scripts/differential_expression.py`)
+
+```bash
+python scripts/differential_expression.py --dir /path/to/annotation_dir
 ```
-### Output Columns
-The returned DataFrame is indexed by gene name and contains:
-* `log_fold_change`: The estimated log-fold change of expression (Target vs. Reference). Positive values indicate upregulation in `cell_pop_b`.
-* `p_value`: The raw p-value from the hypothesis test (two-sided).
-* `p_value_adj`: The p-value adjusted for multiple testing (Benjamini-Hochberg FDR).
+Reads all export settings from `config.json` and writes gene-level results to `/results.csv`.
+ +| option | default | description | +|---|---|---| +| `--dir` | *(required)* | annotation directory (output of steps 1 & 2) | +| `--out` | `/results.csv` | output CSV path | +| `--spatial-group-key` | `spatial_group` | obs column encoding the two spatial populations | +| `--n-cells-expressed-threshold` | `10` | min annotated cells expressing a gene for it to be tested | +| `--noise-model` | `poisson` | `poisson` or `nb` (negative binomial) | + +### Output columns -## How to construct `adata_gt`? +| column | description | +|---|---| +| `log_fold_change` | estimated LFC (positive = upregulated in target population) | +| `p_value` | raw two-sided p-value | +| `p_value_adj` | Benjamini-Hochberg adjusted p-value | -Constructing `adata_gt` requires validating the cell population labels for a small subset of cells (e.g., random sample). This involves: -1. **Sampling**: Select a small random subset of cells from your dataset. -2. **Data Access**: Extract the relevant data for these cells: their gene expression profile, their spatial coordinates, and importantly, a **high-resolution image crop** of the cell (with segmentation boundaries if available) to assess morphology. -3. **Validation**: Visually inspect these data points to determine the true cell identity. -4. **Annotation**: Create the `is_correct` boolean column based on your assessment. +
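For readers unfamiliar with the `p_value_adj` column, the following is a small standard-library sketch of the Benjamini-Hochberg adjustment. It illustrates the textbook procedure only; the helper name `benjamini_hochberg` is hypothetical and this is not necessarily how `csde` computes the column internally.

```python
def benjamini_hochberg(p_values):
    """Return Benjamini-Hochberg adjusted p-values, in the input order."""
    m = len(p_values)
    # Indices of the p-values sorted from smallest to largest.
    order = sorted(range(m), key=lambda i: p_values[i])
    adjusted = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down so that adjusted values
    # never increase as the raw p-value decreases.
    for rank in range(m, 0, -1):
        i = order[rank - 1]
        running_min = min(running_min, p_values[i] * m / rank)
        adjusted[i] = running_min
    return adjusted


p_adj = benjamini_hochberg([0.01, 0.04, 0.03, 0.20])
```

Genes whose adjusted value falls below the chosen false discovery rate (for example 0.05) would be reported as differentially expressed.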
+Python API
-These steps can be performed manually or using dedicated tools.
-Our [experimental repository](https://github.com/YosefLab/csde/blob/main/csde_experiments)
-provides an example of how these steps were performed for MERFISH data.
+The full CSDE statistical model is callable directly from Python, without going through the CLI scripts.
-To streamline this process, for MERFISH or other spatial transcriptomics data, we recommend using **[SpatialData](https://spatialdata.scverse.org/)** to access the data and perform the manual validation.
+`prepare_csde_inputs` reads `config.json`, `metadata.csv`, and `annotations.json` from the annotation directory produced by Steps 1 & 2. It returns two AnnData objects restricted to the same gene set:
+
+- `adata_gt` — the manually validated cells, with an `is_correct` boolean column in `.obs` and a `sampling_weight` column reflecting the importance-sampling weight assigned during export
+- `adata_other` — all remaining cells (not manually validated); their `obs` contains an integer `prediction` column encoding the spatial population each cell was assigned to by the automated pipeline: `0` = reference region, `1` = target region, `2` = neither
+
+```python
+from csde import prepare_csde_inputs, run_csde
+
+inputs = prepare_csde_inputs(
+    annotation_dir="/path/to/annotation_dir",  # same dir as Steps 1 & 2
+    spatial_group_key="spatial_group",
+    n_cells_expressed_threshold=10,
+)
+adata_gt = inputs["adata_gt"]        # manually validated cells
+adata_other = inputs["adata_other"]  # all other cells
+
+results = run_csde(
+    adata_pred=adata_other,
+    adata_gt=adata_gt,
+    pred_cell_pop_key="prediction",  # obs column: 0=reference, 1=target, 2=other
+    cell_pop_a=0,                    # reference population
+    cell_pop_b=1,                    # target population (LFC = log(target/reference))
+    gt_key="is_correct",             # boolean correctness label from Step 2
+    layer_name="counts",
+    importance_weights=adata_gt.obs["sampling_weight"].values,  # from metadata.csv
+)
+# DataFrame indexed by gene: log_fold_change, p_value, p_value_adj
+print(results.head())
+```
+
diff --git a/pyproject.toml b/pyproject.toml index c6da244..529d2f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dev = [ "isort", "flake8", ] +annotate = ["streamlit"] cuda12 = ["jax[cuda12]"] cuda13 = ["jax[cuda13]"] diff --git a/scripts/annotate.py b/scripts/annotate.py new file mode 100644 index 0000000..8419020 --- /dev/null +++ b/scripts/annotate.py @@ -0,0 +1,153 @@ +""" +Streamlit annotation UI for per-cell panels. + +Usage +----- +streamlit run scripts/annotate.py -- --dir /path/to/annotations/R2_macrophages +""" + +import argparse +import json +from pathlib import Path + +import pandas as pd +import streamlit as st +import streamlit.components.v1 as components + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--dir", required=True, help="Annotation directory (output of export.py).") + return p.parse_args() + + +def load_annotations(annotation_dir: Path) -> dict: + ann_path = annotation_dir / "annotations.json" + if ann_path.exists(): + with open(ann_path) as f: + return json.load(f) + return {} + + +def save_annotations(annotations: dict, annotation_dir: Path) -> None: + with open(annotation_dir / "annotations.json", "w") as f: + json.dump(annotations, f, indent=2) + + +def main(): + args = parse_args() + annotation_dir = Path(args.dir) + + cell_type_of_interest = "cell of interest" + config_path = annotation_dir / "config.json" + if config_path.exists(): + with open(config_path) as f: + config = json.load(f) + cell_type_of_interest = config.get("cell_type_of_interest", cell_type_of_interest) + + st.set_page_config(layout="wide", page_title="Cell Annotator") + st.title(f"Cell Annotation — {cell_type_of_interest}") + + metadata_path = annotation_dir / "metadata.csv" + if not metadata_path.exists(): + st.warning("metadata.csv not found — waiting for export.py to write the first cell.") + st.stop() + + metadata = pd.read_csv(metadata_path) + metadata["cell_id"] = metadata["cell_id"].astype(str) + n_total = 
len(metadata) + + annotations = load_annotations(annotation_dir) + n_done = len(annotations) + + st.progress(n_done / n_total, text=f"{n_done} / {n_total} annotated") + + # Initialize navigation index to first unannotated cell + if "current_idx" not in st.session_state: + unannotated_mask = ~metadata["cell_id"].isin(annotations) + first_unannotated = unannotated_mask.idxmax() if unannotated_mask.any() else 0 + st.session_state.current_idx = int(first_unannotated) + + idx = st.session_state.current_idx + + # Jump-to input (form so Enter submits without looping) + jump_col, nav_col = st.columns([2, 3]) + with jump_col: + with st.form("jump_form", clear_on_submit=True): + fc1, fc2 = st.columns([4, 1]) + with fc1: + jump_id = st.text_input("Jump to cell ID", placeholder="paste cell_id here", label_visibility="collapsed") + with fc2: + submitted = st.form_submit_button("Go") + if submitted and jump_id.strip(): + matches = metadata.index[metadata["cell_id"] == jump_id.strip()].tolist() + if matches: + st.session_state.current_idx = matches[0] + idx = matches[0] + else: + st.warning(f"Cell ID `{jump_id.strip()}` not found.") + + with nav_col: + nav1, nav2, nav3 = st.columns([1, 3, 1]) + with nav1: + if st.button("← Prev", use_container_width=True, disabled=(idx == 0)): + st.session_state.current_idx = idx - 1 + st.rerun() + with nav2: + st.markdown(f"
{idx + 1} / {n_total}
", unsafe_allow_html=True)
+        with nav3:
+            if st.button("Next →", use_container_width=True, disabled=(idx == n_total - 1)):
+                st.session_state.current_idx = idx + 1
+                st.rerun()
+
+    row = metadata.iloc[idx]
+    cell_id = row["cell_id"]
+    existing = annotations.get(cell_id)
+    status = "✓ correct" if existing is True else ("✗ incorrect" if existing is False else "not annotated")
+
+    st.subheader(f"Cell `{cell_id}` — predicted: **{row['cell_type']}** — {status}")
+    st.image(str(row["image_path"]), use_container_width=True)
+
+    def annotate(is_correct: bool) -> None:
+        annotations[cell_id] = is_correct
+        save_annotations(annotations, annotation_dir)
+        # Advance to the next unannotated cell after annotating
+        remaining = metadata.index[~metadata["cell_id"].isin(annotations)]
+        next_idx = next((i for i in remaining if i > idx), None)
+        if next_idx is not None:
+            st.session_state.current_idx = int(next_idx)
+        # Test emptiness with len(): Index.any() checks element truthiness,
+        # so it is False when the only remaining row index is 0.
+        elif len(remaining) > 0:
+            st.session_state.current_idx = int(remaining[0])
+
+    col1, col2, _ = st.columns([1, 1, 4])
+    with col1:
+        if st.button("✓ Correct [1]", type="primary", use_container_width=True):
+            annotate(True)
+            st.rerun()
+    with col2:
+        if st.button("✗ Incorrect [2]", use_container_width=True):
+            annotate(False)
+            st.rerun()
+
+    # Keyboard shortcuts: 1/2 annotate, ←/→ navigate
+    components.html("""
+
+    """, height=0)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/export.py b/scripts/export.py
new file mode 100644
index 0000000..7030c03
--- /dev/null
+++ b/scripts/export.py
@@ -0,0 +1,84 @@
+"""
+Export per-cell annotation panels from a SpatialData zarr.
+ +Example +------- +python scripts/export.py \ + --sdata /ewsc/pboyeau/data/processed/region_R2_annotated.zarr \ + --out /ewsc/pboyeau/data/annotations/R2_macrophages \ + --cell-type-key cell_type \ + --cell-type-of-interest macrophages \ + --target-proportion 0.4 \ + --gene-colors scripts/gene_colors_macrophages.json \ + --image-channel Cellbound2 \ + --n-cells 600 +""" + +import argparse +import json +from pathlib import Path + +import spatialdata as sd + +from csde import export_cell_panels + + +def parse_args(): + p = argparse.ArgumentParser(description="Export per-cell annotation panels.") + p.add_argument("--sdata", required=True, help="Path to annotated SpatialData zarr.") + p.add_argument("--out", required=True, help="Output annotation directory.") + p.add_argument("--cell-type-key", default="cell_type") + p.add_argument("--cell-type-of-interest", required=True) + p.add_argument("--target-proportion", type=float, required=True, + help="Desired fraction of cells of interest in the subsample.") + p.add_argument("--gene-colors", default=None, + help="JSON file mapping gene name → colour.") + p.add_argument("--image-channel", default="DAPI") + p.add_argument("--n-cells", type=int, default=600) + p.add_argument("--delta", type=float, default=50.0, + help="Half-width of the spatial crop around each cell (microns).") + p.add_argument("--n-top-genes", type=int, default=15) + p.add_argument("--layer", default=None, + help="AnnData layer for expression counts (default: X).") + p.add_argument("--seed", type=int, default=0) + p.add_argument("--dpi", type=int, default=150) + return p.parse_args() + + +def main(): + args = parse_args() + + gene_colors = None + if args.gene_colors: + with open(args.gene_colors) as f: + gene_colors = json.load(f) + + sdata = sd.read_zarr(args.sdata) + + annotation_dir = Path(args.out) + annotation_dir.mkdir(parents=True, exist_ok=True) + + with open(annotation_dir / "config.json", "w") as f: + json.dump(vars(args), f, indent=2) + + 
metadata = export_cell_panels( + sdata=sdata, + annotation_dir=annotation_dir, + cell_type_key=args.cell_type_key, + cell_type_of_interest=args.cell_type_of_interest, + target_proportion=args.target_proportion, + n_cells=args.n_cells, + image_channel=args.image_channel, + delta=args.delta, + n_top_genes=args.n_top_genes, + layer=args.layer, + gene_colors=gene_colors, + seed=args.seed, + dpi=args.dpi, + ) + print(f"Done. {len(metadata)} cells exported to {args.out}") + print(metadata["cell_type"].value_counts().to_string()) + + +if __name__ == "__main__": + main() diff --git a/src/csde/__init__.py b/src/csde/__init__.py index 3292a89..13d6ad5 100644 --- a/src/csde/__init__.py +++ b/src/csde/__init__.py @@ -1,3 +1,25 @@ +from .annotation import export_cell_panels, load_annotations, prepare_csde_inputs from .api import run_csde +from .model_nb import NBIntercept, NBInterceptModule +from .model_poisson import PoissonIntercept, PoissonInterceptModule +from .spatial_utils import ( + compute_importance_weights, + plot_region, + plot_top_genes, + subsample_cells, +) -__all__ = ["run_csde"] +__all__ = [ + "run_csde", + "PoissonIntercept", + "PoissonInterceptModule", + "NBIntercept", + "NBInterceptModule", + "plot_region", + "plot_top_genes", + "compute_importance_weights", + "subsample_cells", + "export_cell_panels", + "load_annotations", + "prepare_csde_inputs", +] diff --git a/src/csde/_base.py b/src/csde/_base.py new file mode 100644 index 0000000..bc5e322 --- /dev/null +++ b/src/csde/_base.py @@ -0,0 +1,111 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np + + +class PPIAbstractClass: + def __init__( + self, + inputs_gt: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], + inputs_hat: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], + inputs_unl: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], + lambd_mode: str = "overall", + ): + self.inputs_gt = inputs_gt + self.inputs_hat = inputs_hat + self.inputs_unl = inputs_unl + + inputs_are_tuples = 
isinstance(inputs_gt, tuple) + if inputs_are_tuples: + self.n = inputs_gt[0].shape[0] + self.N = inputs_unl[0].shape[0] + else: + self.n = self.inputs_gt.shape[0] + self.N = self.inputs_unl.shape[0] + self.r = float(self.n) / self.N + self.theta = None + self.sigma = None + self.hessian = None + self.v = None + self.lambd_mode = lambd_mode + self.lambd_ = None + + def get_asymptotic_distribution(self) -> Tuple[np.ndarray, np.ndarray]: + self.sigma = self.compute_sigma(self.lambd_) + return self.theta, self.sigma + + def compute_sigma(self, lambd: Union[float, np.ndarray]) -> np.ndarray: + grad_f_unl = self.grad_fn(self.inputs_unl) + grad_f_hat = self.grad_fn(self.inputs_hat) + grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) + grad_f_gt = self.grad_fn(self.inputs_gt) + + grad_f_ = grad_f_all - grad_f_all.mean(axis=0) + vf = (lambd**2) * (grad_f_.T @ grad_f_) / (self.n + self.N) + rect_ = grad_f_gt - lambd * grad_f_hat + rect_ = rect_ - rect_.mean(axis=0) + vdelta = (rect_.T @ rect_) / self.n + v = vdelta + (self.r * vf) + + hess = self.hessian_fn(self.inputs_gt) + self.hessian = hess + self.v = v + return self._compute_sigma(hess, v, self.n) + + @staticmethod + def _compute_sigma(hess: np.ndarray, v: np.ndarray, n: int) -> np.ndarray: + inv_hess = np.linalg.pinv(hess) + sigma_ = inv_hess @ v @ inv_hess + sigma_ = sigma_ / n + return sigma_ + + def get_lambda( + self, + lambd_0: float = 0.5, + idx_to_optimize: Optional[Union[int, List[int]]] = None, + ) -> Union[float, np.ndarray]: + print("get point estimate ...") + self.theta = self.get_pointestimate(lambd_=lambd_0) + print("done") + + hess = self.hessian_fn(self.inputs_gt) + inv_hess = np.linalg.pinv(hess) + grad_f_unl = self.grad_fn(self.inputs_unl) + grad_f_hat = self.grad_fn(self.inputs_hat) + grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) + grad_f_gt = self.grad_fn(self.inputs_gt) + + grad_f_hat_ = grad_f_hat - grad_f_hat.mean(0) + grad_f_gt_ = grad_f_gt - grad_f_gt.mean(0) + cov1 = (grad_f_hat_.T @ 
grad_f_gt_) / self.n + cov2 = (grad_f_gt_.T @ grad_f_hat_) / self.n + + grad_f_ = grad_f_all - grad_f_all.mean(axis=0) + vf = (grad_f_.T @ grad_f_) / (self.n + self.N) + num = inv_hess @ (cov1 + cov2) @ inv_hess + denom = 2 * (1.0 + self.r) * (inv_hess @ vf @ inv_hess) + if self.lambd_mode == "element": + lambd_star = num / denom + return np.diag(lambd_star) + elif idx_to_optimize is not None: + print("optimize lambda for a single theta comp.") + if isinstance(idx_to_optimize, int): + return ( + num[idx_to_optimize, idx_to_optimize] + / denom[idx_to_optimize, idx_to_optimize] + ) + else: + return np.trace(num[idx_to_optimize, :][:, idx_to_optimize]) / np.trace( + denom[idx_to_optimize, :][:, idx_to_optimize] + ) + else: + return np.trace(num) / np.trace(denom) + + def get_pointestimate(self, lambd_: float) -> np.ndarray: + raise NotImplementedError + + def grad_fn(self, inputs: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + raise NotImplementedError + + def hessian_fn(self, inputs: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + raise NotImplementedError diff --git a/src/csde/annotation.py b/src/csde/annotation.py new file mode 100644 index 0000000..e779481 --- /dev/null +++ b/src/csde/annotation.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import io +import json +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +from .spatial_utils import plot_region, plot_top_genes, subsample_cells + + +def prepare_csde_inputs( + annotation_dir: str | Path, + sdata=None, + spatial_group_key: str = "spatial_group", + spatial_group_target=1, + spatial_group_reference=0, + layer: str | None = None, + n_cells_expressed_threshold: int = 10, +) -> dict: + """ + Build adata_gt and adata_other for run_csde() from a completed annotation directory. + + Reads config.json (cell_type_key, cell_type_of_interest, sdata path), + metadata.csv (sampling_weight per annotated cell), and annotations.json + (is_correct per cell) from annotation_dir. 
+ + Label encoding in .obs["prediction"] / .obs["annotation"]: + 1 — cell_type_of_interest with spatial_group == 1 (target) + 0 — cell_type_of_interest with spatial_group == 0 (reference) + 2 — all other cells + + For GT annotation labels, cells predicted as cell_type_of_interest but + marked incorrect (is_correct=False) are reassigned to class 2. + + Parameters + ---------- + annotation_dir + Directory produced by scripts/export.py and scripts/annotate.py. + sdata + Already-loaded SpatialData object. If None, loaded from the path + stored in config.json. + spatial_group_key + obs column encoding the two spatial populations. + spatial_group_target + Value in spatial_group_key that identifies the target population (label 1). + spatial_group_reference + Value in spatial_group_key that identifies the reference population (label 0). + layer + AnnData layer to use for expression counts. Defaults to None (uses .X). + n_cells_expressed_threshold + A gene is kept only if it is expressed (count >= 1) in at least this + many annotated pred-target/reference cells. + + Returns + ------- + dict with keys: + + adata_gt : AnnData + Annotated cells. obs columns added: ``prediction`` (int 0/1/2), + ``annotation`` (int 0/1/2), ``is_correct`` (bool), + ``sampling_weight`` (float). Genes are filtered. + adata_other : AnnData + All unannotated cells. obs column added: ``prediction`` (int 0/1/2). + Same gene set as adata_gt. + """ + import numpy as np + + annotation_dir = Path(annotation_dir) + + with open(annotation_dir / "config.json") as f: + config = json.load(f) + cell_type_key = config["cell_type_key"] + cell_type_of_interest = config["cell_type_of_interest"] + + ann_path = annotation_dir / "annotations.json" + if not ann_path.exists(): + raise FileNotFoundError( + f"No annotations found at {ann_path}. Run scripts/annotate.py first." 
+ ) + with open(ann_path) as f: + annotations = json.load(f) # {cell_id: True/False} + + metadata = pd.read_csv(annotation_dir / "metadata.csv") + metadata["cell_id"] = metadata["cell_id"].astype(str) + sampling_weights = metadata.set_index("cell_id")["sampling_weight"] + + if sdata is None: + import spatialdata as sd + sdata = sd.read_zarr(config["sdata"]) + adata = sdata["table"].copy() + adata.obs_names = adata.obs_names.astype(str) + adata = adata[adata.obs[cell_type_key].notna()].copy() + + # --- Prediction labels (automated, all cells) --- + is_coi = (adata.obs[cell_type_key] == cell_type_of_interest).values + spatial_group = adata.obs[spatial_group_key].values + + prediction = np.full(len(adata), 2, dtype=int) + prediction[is_coi & (spatial_group == spatial_group_target)] = 1 # target + prediction[is_coi & (spatial_group == spatial_group_reference)] = 0 # reference + adata.obs["prediction"] = prediction + + # --- Split annotated / unannotated --- + annotated_ids = set(annotations.keys()) + annotated_mask = adata.obs_names.isin(annotated_ids) + + adata_gt = adata[annotated_mask].copy() + + is_correct_arr = np.array( + [annotations[cid] for cid in adata_gt.obs_names], dtype=bool + ) + adata_gt.obs["is_correct"] = is_correct_arr + + # --- Annotation (GT) labels --- + is_coi_gt = (adata_gt.obs[cell_type_key] == cell_type_of_interest).values + spatial_group_gt = adata_gt.obs[spatial_group_key].values + + annotation = np.full(len(adata_gt), 2, dtype=int) + annotation[is_coi_gt & is_correct_arr & (spatial_group_gt == spatial_group_target)] = 1 + annotation[is_coi_gt & is_correct_arr & (spatial_group_gt == spatial_group_reference)] = 0 + adata_gt.obs["annotation"] = annotation + + adata_gt.obs["sampling_weight"] = adata_gt.obs_names.map(sampling_weights).values + + # --- Gene filter: expressed in >= threshold pred-target/ref cells in adata_gt --- + # pred_mask = adata_gt.obs["prediction"].isin([0, 1]) + pred_mask = adata_gt.obs["annotation"].isin([0, 1]) + _sub = 
adata_gt[pred_mask]
+    x = _sub.layers[layer] if layer is not None else _sub.X
+    if hasattr(x, "toarray"):
+        x = x.toarray()
+    x = x.astype(float)
+    n_expressing = np.array((x >= 1).sum(0)).flatten()
+    gene_mask = n_expressing >= n_cells_expressed_threshold
+
+    adata_gt = adata_gt[:, gene_mask].copy()
+    adata_other = adata[~annotated_mask][:, gene_mask].copy()
+
+    return {"adata_gt": adata_gt, "adata_other": adata_other}
+
+
+def load_annotations(annotation_dir: str | Path) -> pd.DataFrame:
+    """
+    Merge ``metadata.csv`` and ``annotations.json`` into a single DataFrame.
+
+    Returns only annotated cells, with an added ``is_correct`` column.
+    The result is a plain DataFrame; use :func:`prepare_csde_inputs` to build
+    the ``adata_gt`` AnnData expected by :func:`~csde.run_csde`.
+    """
+    annotation_dir = Path(annotation_dir)
+    metadata = pd.read_csv(annotation_dir / "metadata.csv")
+    metadata["cell_id"] = metadata["cell_id"].astype(str)
+
+    ann_path = annotation_dir / "annotations.json"
+    if not ann_path.exists():
+        raise FileNotFoundError(
+            f"No annotations found at {ann_path}. Run scripts/annotate.py first."
+        )
+    with open(ann_path) as f:
+        annotations = json.load(f)
+
+    metadata["is_correct"] = metadata["cell_id"].map(annotations)
+    return metadata[metadata["is_correct"].notna()].copy()
+
+
+def export_cell_panels(
+    sdata,
+    annotation_dir: str | Path,
+    cell_type_key: str,
+    cell_type_of_interest: str,
+    target_proportion: float,
+    n_cells: int = 600,
+    image_channel: str = "DAPI",
+    delta: float = 50.0,
+    n_top_genes: int = 15,
+    layer: str | None = None,
+    gene_colors: dict[str, str] | None = None,
+    seed: int = 0,
+    dpi: int = 150,
+) -> pd.DataFrame:
+    """
+    Subsample cells from a SpatialData object and export per-cell annotation panels.
+ + Each panel is a side-by-side figure: + left — spatial context: fluorescence image + cell boundaries + top-gene transcripts + right — horizontal bar chart of the cell's top expressed genes + + A ``metadata.csv`` is written to ``annotation_dir`` with one row per exported cell, + containing: cell_id, cell_type, image_path, sampling_weight, center_x, center_y. + + Parameters + ---------- + sdata + SpatialData object. Must have a ``"table"`` AnnData whose obs contains + ``cell_type_key``, ``center_x``, and ``center_y``. + annotation_dir + Root output directory. Images are written to ``annotation_dir/images/``. + cell_type_key + obs column that holds cell-type labels. + cell_type_of_interest + Cell type to oversample (e.g. ``"macrophages"``). + target_proportion + Desired fraction of ``cell_type_of_interest`` cells in the subsample, in (0, 1). + n_cells + Total number of cells to sample. + image_channel + Fluorescence channel used as image background (e.g. ``"DAPI"``, ``"Cellbound2"``). + delta + Half-width of the bounding box around each cell center, in microns. + n_top_genes + Number of top-expressed genes shown in the right panel. + layer + AnnData layer to use for expression values; falls back to X when absent or None. + gene_colors + Mapping of gene name → colour. Left panel shows only these genes (with their + colours); right panel uses the colour for keyed genes and grey for the rest. + seed + Random seed for reproducibility. + dpi + Resolution of saved images. + + Returns + ------- + pd.DataFrame + The metadata table also written to ``annotation_dir/metadata.csv``. 
+ """ + try: + from tqdm import tqdm + except ImportError: + def tqdm(it, **_): + return it + + annotation_dir = Path(annotation_dir) + images_dir = annotation_dir / "images" + images_dir.mkdir(parents=True, exist_ok=True) + + adata = sdata["table"] + adata = adata[adata.obs[cell_type_key].notna()].copy() + sub = subsample_cells( + adata, + cell_type_key=cell_type_key, + cell_type_of_interest=cell_type_of_interest, + target_proportion=target_proportion, + n_cells=n_cells, + seed=seed, + ) + + metadata_rows = [] + for cell_id in tqdm(sub.obs_names, desc="Exporting cells"): + obs = sub.obs.loc[cell_id] + cx, cy = float(obs["center_x"]), float(obs["center_y"]) + cell_type = obs[cell_type_key] + + cell_adata = sub[[cell_id]] + + # --- left panel: spatial context --- + ax_left = plot_region( + sdata, + xmin=cx - delta, + xmax=cx + delta, + ymin=cy - delta, + ymax=cy + delta, + gene_colors=gene_colors or {}, + image_channel=image_channel, + coord_system="micron", + figsize=(7, 7), + ) + buf = io.BytesIO() + ax_left.figure.savefig(buf, format="png", dpi=dpi, bbox_inches="tight") + buf.seek(0) + plt.close(ax_left.figure) + + img_left = plt.imread(buf) + + # --- combined figure --- + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + axes[0].imshow(img_left) + axes[0].axis("off") + axes[0].set_title(f"{cell_type} | {cell_id}", fontsize=8) + + plot_top_genes( + cell_adata, + n_genes=n_top_genes, + ax=axes[1], + layer=layer, + title="Top expressed genes", + gene_colors=gene_colors, + ) + + fig.tight_layout() + img_path = images_dir / f"cell_{cell_id}.png" + fig.savefig(img_path, dpi=dpi, bbox_inches="tight") + plt.close(fig) + + metadata_rows.append( + { + "cell_id": cell_id, + "cell_type": cell_type, + "image_path": str(img_path), + "sampling_weight": float(obs["sampling_weight"]), + "center_x": cx, + "center_y": cy, + } + ) + pd.DataFrame(metadata_rows).to_csv(annotation_dir / "metadata.csv", index=False) + + return pd.DataFrame(metadata_rows) diff --git a/src/csde/api.py 
b/src/csde/api.py index edfcae5..a008b67 100644 --- a/src/csde/api.py +++ b/src/csde/api.py @@ -4,7 +4,8 @@ import numpy as np import pandas as pd -from csde.model import InterceptRegression +from csde.model_poisson import PoissonIntercept +from csde.model_nb import NBIntercept def _map_cell_types( @@ -45,6 +46,8 @@ def run_csde( cell_pop_b: str, gt_key: str, layer_name: Optional[str] = None, + importance_weights: Optional[np.ndarray] = None, + noise_model: str = "poisson", **model_kwargs, ) -> pd.DataFrame: """ @@ -61,7 +64,10 @@ def run_csde( cell_pop_b: Name of the second cell population (target group). gt_key: Boolean column in adata_gt.obs indicating if the prediction is correct. layer_name: Layer in adata.layers to use for expression counts. If None, uses .X. - **model_kwargs: Additional arguments passed to InterceptRegression (e.g., family, optimizer). + importance_weights: Optional 1-D array of importance weights for the ground-truth + observations. Will be normalized to sum to n_obs internally. + noise_model: Noise model to use. Either "poisson" or "nb". + **model_kwargs: Additional arguments passed to PoissonIntercept (e.g., optimizer). 
Returns: DataFrame indexed by gene names with columns: @@ -106,12 +112,24 @@ def get_X(adata): inputs_unl = (X_unl, y_pred_unl) # inference - model = InterceptRegression( + if noise_model == "poisson": + model = PoissonIntercept( inputs_gt=inputs_gt, inputs_hat=inputs_hat, inputs_unl=inputs_unl, + importance_weights=importance_weights, **model_kwargs, ) + elif noise_model == "nb": + model = NBIntercept( + inputs_gt=inputs_gt, + inputs_hat=inputs_hat, + inputs_unl=inputs_unl, + importance_weights=importance_weights, + **model_kwargs, + ) + else: + raise ValueError(f"Unknown noise model: {noise_model}") model.fit(lambd_=None) model.get_asymptotic_distribution() diff --git a/src/csde/model.py b/src/csde/model_poisson.py similarity index 57% rename from src/csde/model.py rename to src/csde/model_poisson.py index 1979849..b221437 100644 --- a/src/csde/model.py +++ b/src/csde/model_poisson.py @@ -5,170 +5,21 @@ import jax.numpy as jnp import numpy as np import pandas as pd -from numpyro.distributions import Normal, Poisson +from numpyro.distributions import Poisson from statsmodels.stats.multitest import multipletests from tqdm import tqdm +from csde._base import PPIAbstractClass from csde.optimization import _zstat_generic2, optimize_ppi, optimize_ppi_gd -jax.config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", False) -class PPIAbstractClass: - """ - Abstract base class for Prediction-Powered Inference (PPI) models. - """ - - def __init__( - self, - inputs_gt: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], - inputs_hat: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], - inputs_unl: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], - lambd_mode: str = "overall", - ): - """ - Initialize the PPI model. - - Args: - inputs_gt: Ground truth data (features, labels). - inputs_hat: Predicted data for the labeled set (features, predicted labels). - inputs_unl: Unlabeled data (features, predicted labels). 
- lambd_mode: Mode for lambda parameter ('overall' or 'element'). - """ - self.inputs_gt = inputs_gt - self.inputs_hat = inputs_hat - self.inputs_unl = inputs_unl - - inputs_are_tuples = isinstance(inputs_gt, tuple) - if inputs_are_tuples: - self.n = inputs_gt[0].shape[0] - self.N = inputs_unl[0].shape[0] - else: - self.n = self.inputs_gt.shape[0] - self.N = self.inputs_unl.shape[0] - self.r = float(self.n) / self.N - self.theta = None - self.sigma = None - self.hessian = None - self.v = None - self.lambd_mode = lambd_mode - self.lambd_ = None - - def get_asymptotic_distribution(self) -> Tuple[np.ndarray, np.ndarray]: - """ - Compute the asymptotic distribution of the estimator. - - Returns: - Tuple containing the point estimate (theta) and the covariance matrix (sigma). - """ - self.sigma = self.compute_sigma(self.lambd_) - return self.theta, self.sigma - - def compute_sigma(self, lambd: Union[float, np.ndarray]) -> np.ndarray: - """ - Compute the covariance matrix of the estimator. - """ - grad_f_unl = self.grad_fn(self.inputs_unl) - grad_f_hat = self.grad_fn(self.inputs_hat) - grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) - grad_f_gt = self.grad_fn(self.inputs_gt) - - grad_f_ = grad_f_all - grad_f_all.mean(axis=0) - - # Handle lambda broadcasting if necessary - if self.lambd_mode == "element" and isinstance(lambd, np.ndarray): - # This part was implementation specific in subclass, but generalized here based on pattern - # Assuming lambd matches gradient dimensions or is handled in subclass override - pass - - # Base implementation for scalar lambda, override in subclass if needed - vf = (lambd**2) * (grad_f_.T @ grad_f_) / (self.n + self.N) - rect_ = grad_f_gt - lambd * grad_f_hat - rect_ = rect_ - rect_.mean(axis=0) - vdelta = (rect_.T @ rect_) / self.n - v = vdelta + (self.r * vf) - - hess = self.hessian_fn(self.inputs_gt) - self.hessian = hess - self.v = v - return self._compute_sigma(hess, v, self.n) - - @staticmethod - def _compute_sigma(hess: 
np.ndarray, v: np.ndarray, n: int) -> np.ndarray: - """ - Compute the asymptotic covariance matrix of the parameter estimates. - """ - inv_hess = np.linalg.pinv(hess) - sigma_ = inv_hess @ v @ inv_hess - sigma_ = sigma_ / n - return sigma_ - - def get_lambda( - self, - lambd_0: float = 0.5, - idx_to_optimize: Optional[Union[int, List[int]]] = None, - ) -> Union[float, np.ndarray]: - """ - Estimate the optimal lambda parameter. - """ - print("get point estimate ...") - self.theta = self.get_pointestimate(lambd_=lambd_0) - print("done") - - hess = self.hessian_fn(self.inputs_gt) - - inv_hess = np.linalg.pinv(hess) - grad_f_unl = self.grad_fn(self.inputs_unl) - grad_f_hat = self.grad_fn(self.inputs_hat) - grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) - grad_f_gt = self.grad_fn(self.inputs_gt) - - grad_f_hat_ = grad_f_hat - grad_f_hat.mean(0) - grad_f_gt_ = grad_f_gt - grad_f_gt.mean(0) - cov1 = (grad_f_hat_.T @ grad_f_gt_) / self.n - cov2 = (grad_f_gt_.T @ grad_f_hat_) / self.n - - grad_f_ = grad_f_all - grad_f_all.mean(axis=0) - vf = (grad_f_.T @ grad_f_) / (self.n + self.N) - num = inv_hess @ (cov1 + cov2) @ inv_hess - denom = 2 * (1.0 + self.r) * (inv_hess @ vf @ inv_hess) - if self.lambd_mode == "element": - lambd_star = num / denom - return np.diag(lambd_star) - elif idx_to_optimize is not None: - print("optimize lambda for a single theta comp.") - if isinstance(idx_to_optimize, int): - return ( - num[idx_to_optimize, idx_to_optimize] - / denom[idx_to_optimize, idx_to_optimize] - ) - else: - return np.trace(num[idx_to_optimize, :][:, idx_to_optimize]) / np.trace( - denom[idx_to_optimize, :][:, idx_to_optimize] - ) - else: - return np.trace(num) / np.trace(denom) - - def get_pointestimate(self, lambd_: float) -> np.ndarray: - raise NotImplementedError - - def grad_fn(self, inputs: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: - raise NotImplementedError - - def hessian_fn(self, inputs: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: - raise NotImplementedError - 
- -class RegressionInterceptModel(nn.Module): - """ - Flax module for the intercept regression model. - """ - +class PoissonInterceptModule(nn.Module): n_classes: int n_features: int mu_prior_std: Union[float, jnp.ndarray] n_obs_real: int - family: str def setup(self): self.mu0 = self.param("mu0", nn.initializers.normal(), (self.n_features)) @@ -176,55 +27,37 @@ def setup(self): "mu", nn.initializers.normal(), (self.n_classes - 1, self.n_features) ) - def __call__(self, x, y): + def __call__(self, x, y, w=None): y_ = y.astype(jnp.int32) mu_placeholder = jnp.zeros_like(self.mu0) mu = jnp.concatenate([mu_placeholder[None], self.mu], axis=0) y_oh = jnp.eye(self.n_classes)[y_] mus_ = y_oh @ mu + self.mu0 - if self.family == "poisson": - rates = jnp.exp(mus_) - log_px_c_unsummed = Poisson(rate=rates).log_prob(x) - log_px_c = log_px_c_unsummed.sum(axis=-1) - elif self.family == "gaussian": - log_px_c_unsummed = Normal(loc=mus_, scale=1.0).log_prob(x) - log_px_c = log_px_c_unsummed.sum(axis=-1) - else: - raise ValueError(f"Unknown family: {self.family}") + if w is None: + w = jnp.ones_like(y, dtype=jnp.float64) + + rates = jnp.exp(mus_) + log_px_c_unsummed = Poisson(rate=rates).log_prob(x) + log_px_c = log_px_c_unsummed.sum(axis=-1) loss = -log_px_c return { - "loss": loss, - "loss_unsummed": -log_px_c_unsummed, + "loss": loss * w, + "loss_unsummed": -log_px_c_unsummed * w[..., None], } -class InterceptRegression(PPIAbstractClass): - """ - Intercept Regression model for spatial differential expression analysis. - """ - +class PoissonIntercept(PPIAbstractClass): def __init__( self, mu_prior_std: Optional[Union[float, jnp.ndarray]] = None, optimizer: str = "gd", optimizer_kwargs: Optional[Dict[str, Any]] = None, - family: str = "poisson", jit: bool = True, + importance_weights: Optional[np.ndarray] = None, **kwargs, ): - """ - Initialize the InterceptRegression model. - - Args: - mu_prior_std: Prior standard deviation for mu. 
- optimizer: Optimization method ('gd' or 'lbfgs'). - optimizer_kwargs: Keyword arguments for the optimizer. - family: Distribution family ('poisson' or 'gaussian'). - jit: Whether to JIT compile the optimization. - **kwargs: Arguments passed to PPIAbstractClass (inputs_gt, inputs_hat, inputs_unl). - """ super().__init__(**kwargs) x_gt, y_gt = self.inputs_gt @@ -239,16 +72,27 @@ def __init__( self.inputs_gt = (x_gt, y_gt) self.inputs_hat = (x_hat, y_hat) self.inputs_unl = (x_unl, y_unl) + + if importance_weights is not None: + if importance_weights.shape != (x_gt.shape[0],): + raise ValueError( + "importance_weights must be a 1-D array with the same length " + "as the number of ground-truth observations" + ) + w = float(x_gt.shape[0]) * importance_weights / importance_weights.sum() + self.importance_weights = w + else: + self.importance_weights = None + n_obs_real = x_gt.shape[0] self.n_features = x_gt.shape[1] self.n_params = (self.n_classes - 1) * self.n_features + self.n_features - self.model = RegressionInterceptModel( + self.model = PoissonInterceptModule( n_classes=self.n_classes, n_features=self.n_features, mu_prior_std=mu_prior_std, n_obs_real=n_obs_real, - family=family, ) self.model_params = None @@ -262,13 +106,6 @@ def __init__( def fit( self, lambd_: Optional[Union[float, np.ndarray]] = None, refit: bool = False ): - """ - Fit the model parameters. - - Args: - lambd_: Lambda parameter. If None, it is estimated. - refit: Whether to re-initialize parameters before fitting. - """ if lambd_ is None: lambd_ = self.get_lambda() print(f"lambda: {lambd_}") @@ -282,25 +119,18 @@ def get_lambda( lambd_0: float = 0.5, idx_to_optimize: Optional[Union[int, List[int]]] = None, ) -> Union[float, np.ndarray]: - """ - Estimate the optimal lambda parameter. - Overriding parent method to handle element-wise lambda specific logic. 
- """ - # Call parent to get num and denom matrices/values if needed, but the parent implementation - # might need access to _construct_contrast for element-wise which is specific to this class. - # So copying logic from original implementation to be safe and consistent. - print("get point estimate ...") self.theta = self.get_pointestimate(lambd_=lambd_0) print("done") - hess = self.hessian_fn(self.inputs_gt) - + hess = self.hessian_fn( + self.inputs_gt, importance_weights=self.importance_weights + ) inv_hess = np.linalg.pinv(hess) grad_f_unl = self.grad_fn(self.inputs_unl) - grad_f_hat = self.grad_fn(self.inputs_hat) + grad_f_hat = self.grad_fn(self.inputs_hat, w=self.importance_weights) grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) - grad_f_gt = self.grad_fn(self.inputs_gt) + grad_f_gt = self.grad_fn(self.inputs_gt, w=self.importance_weights) grad_f_hat_ = grad_f_hat - grad_f_hat.mean(0) grad_f_gt_ = grad_f_gt - grad_f_gt.mean(0) @@ -342,6 +172,7 @@ def get_pointestimate(self, lambd_: Union[float, np.ndarray]) -> np.ndarray: model_params0 = self.model_params if self.model_params is not None else None if self.optimizer == "lbfgs": + print("optimize with lbfgs") model_params = optimize_ppi( self.model, lambd_=lambd_, @@ -355,6 +186,7 @@ def get_pointestimate(self, lambd_: Union[float, np.ndarray]) -> np.ndarray: **self.optimizer_kwargs, ) elif self.optimizer == "gd": + print("optimize with gd") model_params = optimize_ppi_gd( self.model, lambd_=lambd_, @@ -364,6 +196,7 @@ def get_pointestimate(self, lambd_: Union[float, np.ndarray]) -> np.ndarray: y_hat=y_hat, x_unl=x_unl, y_unl=y_unl, + w=self.importance_weights, model_params0=model_params0, **self.optimizer_kwargs, ) @@ -377,11 +210,10 @@ def get_pointestimate(self, lambd_: Union[float, np.ndarray]) -> np.ndarray: return np.hstack([mu, mu0]) def compute_sigma(self, lambd: Union[float, np.ndarray]) -> np.ndarray: - # Override to handle element-wise lambda and specific broadcasting grad_f_unl = 
self.grad_fn(self.inputs_unl) - grad_f_hat = self.grad_fn(self.inputs_hat) + grad_f_hat = self.grad_fn(self.inputs_hat, w=self.importance_weights) grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) - grad_f_gt = self.grad_fn(self.inputs_gt) + grad_f_gt = self.grad_fn(self.inputs_gt, w=self.importance_weights) grad_f_ = grad_f_all - grad_f_all.mean(axis=0) if self.lambd_mode == "element": @@ -396,36 +228,41 @@ def compute_sigma(self, lambd: Union[float, np.ndarray]) -> np.ndarray: vdelta = (rect_.T @ rect_) / self.n v = vdelta + (self.r * vf) - hess = self.hessian_fn(self.inputs_gt) + hess = self.hessian_fn( + self.inputs_gt, importance_weights=self.importance_weights + ) self.hessian = hess self.v = v return self._compute_sigma(hess, v, self.n) def grad_fn( - self, inputs: Tuple[np.ndarray, np.ndarray], batch_size: int = 128 + self, + inputs: Tuple[np.ndarray, np.ndarray], + w: Optional[np.ndarray] = None, + batch_size: int = 128, ) -> np.ndarray: x, y = inputs n_obs = x.shape[0] - def likelihood(model_params, x, y): - return self.model.apply(model_params, x, y)["loss"] + def likelihood(model_params, x, y, w=None): + return self.model.apply(model_params, x, y, w=w)["loss"] + score = self.jit(jax.jacfwd(likelihood)) all_grads = np.zeros((n_obs, self.n_params)) for i in tqdm(range(0, n_obs, batch_size), desc="Gradient computation"): - x_batch = x[i:i+batch_size] - y_batch = y[i:i+batch_size] + x_batch = x[i : i + batch_size] + y_batch = y[i : i + batch_size] + w_batch = w[i : i + batch_size] if w is not None else None n_obs_batch = x_batch.shape[0] - score = self.jit(jax.jacfwd(likelihood)) - grads = score(self.model_params, x_batch, y_batch) + grads = score(self.model_params, x_batch, y_batch, w=w_batch) grad_mu = np.array(grads["params"]["mu"].reshape(n_obs_batch, -1)) grad_mu0 = np.array(grads["params"]["mu0"].reshape(n_obs_batch, -1)) - all_grads[i:i+batch_size] = np.hstack([grad_mu, grad_mu0]) + all_grads[i : i + batch_size] = np.hstack([grad_mu, grad_mu0]) return 
np.array(all_grads) def _construct_contrast(self, feature_id: int, idx_a: int) -> np.ndarray: mu_contrast = np.zeros((self.n_classes - 1, self.n_features)) mu_contrast[idx_a - 1, feature_id] = 1.0 - mu0_contrast = np.zeros(self.n_features) contrast = np.hstack([mu_contrast.flatten(), mu0_contrast]) return contrast.astype(int) @@ -433,7 +270,6 @@ def _construct_contrast(self, feature_id: int, idx_a: int) -> np.ndarray: def idx_to_feat(self) -> np.ndarray: mu_identifier = np.ones((self.n_classes - 1, self.n_features)) mu_identifier = mu_identifier * np.arange(self.n_features) - mu0_identifier = np.arange(self.n_features) identifier = np.hstack([mu_identifier.flatten(), mu0_identifier]) return identifier.astype(int) @@ -443,13 +279,11 @@ def construct_contrast(self, idx_a: int) -> np.ndarray: self._construct_contrast(feature_id, idx_a) for feature_id in range(self.n_features) ] - _contrast = np.vstack(_contrast) - return _contrast + return np.vstack(_contrast) def get_beta(self, idx_a: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: if idx_a == 0: raise ValueError("`class_a` cannot be the reference class.") - contrast = self.construct_contrast(idx_a) beta = contrast @ self.theta cov = contrast @ self.sigma @ contrast.T @@ -473,8 +307,7 @@ def _get_param_mask(self, feature_id: int) -> np.ndarray: for class_id in range(1, self.n_classes) ] mu0_indices = [self._get_param_id(feature_id=feature_id, param_type="mu0")] - indices_to_keep = np.hstack([mu_indices, mu0_indices]) - return indices_to_keep + return np.hstack([mu_indices, mu0_indices]) def test_differential_expression( self, @@ -482,17 +315,6 @@ def test_differential_expression( feature_names: Optional[List[str]] = None, cond_thresh: float = np.inf, ) -> pd.DataFrame: - """ - Perform differential expression testing. - - Args: - idx_a: The index of the target class (1 or 2). Note: 0 is the reference class. - feature_names: List of feature names. - cond_thresh: Condition number threshold for Hessian. 
- - Returns: - DataFrame containing the results (p-values, log-fold changes, etc.). - """ idx_a_ = idx_a - 1 results = [] for feature_id in range(self.n_features): @@ -522,7 +344,7 @@ def test_differential_expression( } ) res = pd.DataFrame(results) - res["pval"].iloc[np.isnan(res["pval"])] = 1.0 + res.loc[np.isnan(res["pval"]), "pval"] = 1.0 res["padj"] = multipletests(res["pval"], method="fdr_bh")[1] res["is_significant_005"] = res["padj"] < 0.05 if feature_names is not None: @@ -541,7 +363,10 @@ def zero_init(self): self.model_params = params def hessian_fn( - self, inputs: Tuple[np.ndarray, np.ndarray], device=None + self, + inputs: Tuple[np.ndarray, np.ndarray], + importance_weights: Optional[np.ndarray] = None, + device=None, ) -> np.ndarray: x, y = inputs @@ -552,13 +377,13 @@ def hessian_fn( obs_ids = np.arange(n_obs) model_ = self.model - def likelihood(model_params, x, y): - return model_.apply(model_params, x, y)["loss"] + def likelihood(model_params, x, y, w=None): + return model_.apply(model_params, x, y, w=w)["loss"] hess_fn = jax.hessian(likelihood) - def process_hess(x, y): - hess_ = hess_fn(model_params_, x, y) + def process_hess(x, y, w=None): + hess_ = hess_fn(model_params_, x, y, w=w) mu_mu = ( hess_["params"]["mu"]["params"]["mu"] .mean(0) @@ -577,13 +402,12 @@ def process_hess(x, y): .mean(0) .reshape(self.n_features, self.n_features) ) - blk = jnp.block( + return jnp.block( [ [mu_mu, mu_mu0], [mu_mu0.T, mu0_mu0], ] ) - return blk hessian = np.zeros((self.n_params, self.n_params), dtype=np.float64) for obs_id in tqdm(obs_ids, desc="Hessian computation"): @@ -591,6 +415,10 @@ def process_hess(x, y): y_ = jnp.array(y[[obs_id]], dtype=jnp.int32) x_obs = jax.device_put(x_, device) y_obs = jax.device_put(y_, device) - hess_ = process_hess(x_obs, y_obs) - hessian += hess_ / float(n_obs) + if importance_weights is not None: + w_ = jnp.array(importance_weights[[obs_id]], dtype=jnp.float64) + w_obs = jax.device_put(w_, device) + else: + w_obs = None + 
hessian += process_hess(x_obs, y_obs, w_obs) / float(n_obs) return hessian diff --git a/src/csde/optimization.py b/src/csde/optimization.py index 9269d1b..e7ffa02 100644 --- a/src/csde/optimization.py +++ b/src/csde/optimization.py @@ -102,6 +102,7 @@ def optimize_ppi_gd( y_hat: jnp.ndarray, x_unl: jnp.ndarray, y_unl: jnp.ndarray, + w: Optional[jnp.ndarray] = None, model_params0: Optional[Any] = None, lambd_: float = 1.0, tol: float = 1e-3, @@ -123,6 +124,8 @@ def optimize_ppi_gd( y_hat = jax.device_put(jnp.array(y_hat, dtype=jnp.int32)) x_unl = jax.device_put(jnp.array(x_unl, dtype=jnp.float64)) y_unl = jax.device_put(jnp.array(y_unl, dtype=jnp.int32)) + if w is not None: + w = jax.device_put(jnp.array(w, dtype=jnp.float64)) x0 = jnp.ones((32, x_gt.shape[1]), dtype=jnp.float64) y0 = jnp.ones(32, dtype=jnp.int32) @@ -142,23 +145,35 @@ def optimize_ppi_gd( lambd_ = jax.device_put(lambd_) - def loss_fn(zetas): - loss_gt = model.apply(zetas, x_gt, y_gt)["loss_unsummed"].mean(0) - loss_hat = model.apply(zetas, x_hat, y_hat)["loss_unsummed"].mean(0) - loss_unl = model.apply(zetas, x_unl, y_unl)["loss_unsummed"].mean(0) - loss = (lambd_ * loss_unl) - (lambd_ * loss_hat) + loss_gt - loss = loss.sum(-1) - return loss + def step_fn(theta_, opt_state_, x_gt_, y_gt_, x_hat_, y_hat_, x_unl_, y_unl_): + def loss_fn(zetas): + loss_gt = model.apply(zetas, x_gt_, y_gt_, w=w)["loss_unsummed"].mean(0) + loss_hat = model.apply(zetas, x_hat_, y_hat_, w=w)["loss_unsummed"].mean(0) + loss_unl = model.apply(zetas, x_unl_, y_unl_)["loss_unsummed"].mean(0) + loss = (lambd_ * loss_unl) - (lambd_ * loss_hat) + loss_gt + loss = loss.sum(-1) + + # loss_gt = model.apply(zetas, x_gt_, y_gt_, w=w)["loss"].mean() + # loss_hat = model.apply(zetas, x_hat_, y_hat_, w=w)["loss"].mean() + # loss_unl = model.apply(zetas, x_unl_, y_unl_)["loss"].mean() + # loss = (lambd_ * loss_unl) - (lambd_ * loss_hat) + loss_gt + return loss + + loss, grad = jax.value_and_grad(loss_fn)(theta_) + updates, opt_state_ = 
opt.update(grad, opt_state_, theta_) + theta_ = optax.apply_updates(theta_, updates) + return theta_, opt_state_, loss + + compiled_step = jitter(step_fn) - value_and_grad_fn = jitter(jax.value_and_grad(loss_fn)) previous_loss = 1e6 print("lambda:", lambd_) print("tol:", tol_) pbar = trange(n_iter) for _ in pbar: - loss, grad = value_and_grad_fn(theta) - updates, opt_state = opt.update(grad, opt_state, theta) - theta = optax.apply_updates(theta, updates) + theta, opt_state, loss = compiled_step( + theta, opt_state, x_gt, y_gt, x_hat, y_hat, x_unl, y_unl + ) stopping_criterion = np.abs(loss - previous_loss) if np.allclose(loss, previous_loss, atol=tol_, rtol=0): diff --git a/src/csde/spatial_utils.py b/src/csde/spatial_utils.py new file mode 100644 index 0000000..d3e4e46 --- /dev/null +++ b/src/csde/spatial_utils.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +from typing import Literal + +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + + +def _micron_to_global(sdata, xmin: float, ymin: float, xmax: float, ymax: float): + """Convert bounding-box corners from micron (intrinsic) to global (pixel) space.""" + from spatialdata.transformations import get_transformation + + shapes_key = list(sdata.shapes.keys())[0] + t = get_transformation(sdata.shapes[shapes_key], to_coordinate_system="global") + m = t.matrix # 3×3 affine; for MERSCOPE this is a pure scale+translation + + # Transform all four corners and take the bounding box (handles negative scales). 
+ xs = [m[0, 0] * x + m[0, 2] for x in (xmin, xmax)] + ys = [m[1, 1] * y + m[1, 2] for y in (ymin, ymax)] + return min(xs), min(ys), max(xs), max(ys) + + +def plot_region( + sdata, + xmin: float, + ymin: float, + xmax: float, + ymax: float, + gene_colors: dict[str, str], + image_channel: str, + coord_system: Literal["global", "micron"] = "global", + image_cmap: str = "gray", + outline_color: str = "white", + point_size: int = 2, + figsize: tuple = (14, 14), +): + """ + Plot a spatial region: fluorescence image + cell boundaries + transcript dots. + + Parameters + ---------- + sdata + SpatialData object. + xmin, ymin, xmax, ymax + Bounding-box corners. Interpreted according to ``coord_system``. + gene_colors + Mapping of gene name → hex/named colour. Only these genes are shown + as transcript dots, using their assigned colours. + image_channel + Fluorescence channel name to use as background (e.g. ``"DAPI"``, + ``"Cellbound1"``). + coord_system + ``"global"`` – coordinates are in mosaic-pixel space (the SpatialData + global coordinate system). + ``"micron"`` – coordinates are in physical microns, i.e. the same units + as ``center_x`` / ``center_y`` stored in ``adata.obs``. They are + converted to pixel space before querying. + image_cmap + Colormap for the fluorescence background. + outline_color + Colour for cell-boundary outlines. + point_size + Size of transcript dots. + figsize + Figure size passed to matplotlib. 
+ """ + import spatialdata_plot # noqa: F401 — registers the .pl accessor on SpatialData + + shapes_key = list(sdata.shapes.keys())[0] + image_key = list(sdata.images.keys())[0] + points_key = list(sdata.points.keys())[0] + + if coord_system == "micron": + xmin, ymin, xmax, ymax = _micron_to_global(sdata, xmin, ymin, xmax, ymax) + elif coord_system != "global": + raise ValueError(f"coord_system must be 'global' or 'micron', got {coord_system!r}") + + cropped = sdata.query.bounding_box( + min_coordinate=[xmin, ymin], + max_coordinate=[xmax, ymax], + axes=("x", "y"), + target_coordinate_system="global", + ) + + genes = list(gene_colors.keys()) + palette = list(gene_colors.values()) + + renderer = ( + cropped.pl + .render_images(image_key, channel=image_channel, cmap=image_cmap) + .pl.render_shapes( + shapes_key, + fill_alpha=0.0, + outline_alpha=0.8, + outline_color=outline_color, + outline_width=0.5, + ) + ) + if genes: + renderer = renderer.pl.render_points( + points_key, + color="gene", + groups=genes, + palette=palette, + size=point_size, + alpha=0.9, + ) + ax = renderer.pl.show( + figsize=figsize, + title=f"Region [{xmin:.0f}-{xmax:.0f}, {ymin:.0f}-{ymax:.0f}]", + return_ax=True, + ) + return ax + + +def compute_importance_weights( + adata, + cell_type_key: str, + cell_type_of_interest: str, + target_proportion: float, +) -> np.ndarray: + """ + Return unnormalized per-cell weights so that after sampling the fraction of + `cell_type_of_interest` cells equals `target_proportion`. + + Cells of interest receive weight w; all others receive weight 1. + Derived from: n_int * w / (n_int * w + n_other) = target_proportion. 
+ """ + is_interest = adata.obs[cell_type_key] == cell_type_of_interest + n_interest = int(is_interest.sum()) + n_other = int((~is_interest).sum()) + + if n_interest == 0: + return np.ones(len(adata)) + if not (0.0 < target_proportion < 1.0): + raise ValueError("target_proportion must be in (0, 1)") + + w = target_proportion * n_other / (n_interest * (1.0 - target_proportion)) + return np.where(is_interest.values, w, 1.0) + + +def subsample_cells( + adata, + cell_type_key: str, + cell_type_of_interest: str, + target_proportion: float, + n_cells: int, + seed: int = 0, +): + """ + Sample `n_cells` from `adata` without replacement using importance weights + that target `target_proportion` cells of `cell_type_of_interest`. + + The unnormalized sampling weight is stored in ``obs["sampling_weight"]`` + of the returned AnnData (needed downstream for weighted estimation). + """ + weights = compute_importance_weights( + adata, cell_type_key, cell_type_of_interest, target_proportion + ) + norm_weights = weights / weights.sum() + rng = np.random.default_rng(seed) + indices = rng.choice(len(adata), size=n_cells, p=norm_weights, replace=False) + sub = adata[indices].copy() + sub.obs["sampling_weight"] = weights[indices] + return sub + + +def plot_top_genes( + adata_cell, + n_genes: int = 15, + ax=None, + layer: str | None = None, + title: str | None = None, + gene_colors: dict[str, str] | None = None, +): + """ + Horizontal bar chart of the top `n_genes` expressed genes for a single cell. + + Parameters + ---------- + adata_cell + AnnData slice for one cell (shape 1 × n_vars). + n_genes + Number of top genes to display. + ax + Existing matplotlib axes; created if None. + layer + Layer to use for expression values; falls back to X if absent or None. + title + Axes title. 
+ """ + if ax is None: + _, ax = plt.subplots(figsize=(5, 5)) + + if layer is not None and layer in adata_cell.layers: + x = adata_cell.layers[layer] + else: + x = adata_cell.X + + if hasattr(x, "toarray"): + x = x.toarray() + x = np.asarray(x).flatten() + + expr = ( + pd.Series(x, index=adata_cell.var_names) + .sort_values(ascending=False) + .head(n_genes) + .iloc[::-1] # highest bar at the top + ) + + colors = ( + [gene_colors.get(g, "grey") for g in expr.index] + if gene_colors + else ["steelblue"] * len(expr) + ) + ax.barh(expr.index, expr.values, color=colors) + ax.set_xlabel("counts") + ax.tick_params(axis="y", labelsize=8) + if title is not None: + ax.set_title(title) + return ax diff --git a/tests/test_csde.py b/tests/test_csde.py index 54c06d7..4a63e41 100644 --- a/tests/test_csde.py +++ b/tests/test_csde.py @@ -1,7 +1,9 @@ import unittest + +import anndata import numpy as np import pandas as pd -import anndata + from csde import run_csde @@ -53,6 +55,71 @@ def test_run_csde(self): ) self.assertTrue(not res.isnull().values.any()) + def test_run_csde_with_importance_weights(self): + n_gt = len(self.adata_gt) + rng = np.random.default_rng(0) + importance_weights = rng.uniform(0.5, 2.0, size=n_gt) + + res = run_csde( + adata_pred=self.adata_pred, + adata_gt=self.adata_gt, + pred_cell_pop_key="cell_type", + cell_pop_a="TypeA", + cell_pop_b="TypeB", + gt_key="is_correct", + optimizer="gd", + optimizer_kwargs={"n_iter": 10}, + importance_weights=importance_weights, + noise_model="poisson", + ) + + self.assertIsInstance(res, pd.DataFrame) + self.assertEqual(len(res), 10) + self.assertListEqual( + list(res.columns), ["log_fold_change", "p_value", "p_value_adj"] + ) + self.assertTrue(not res.isnull().values.any()) + + res = run_csde( + adata_pred=self.adata_pred, + adata_gt=self.adata_gt, + pred_cell_pop_key="cell_type", + cell_pop_a="TypeA", + cell_pop_b="TypeB", + gt_key="is_correct", + optimizer="gd", + optimizer_kwargs={"n_iter": 10}, + 
importance_weights=importance_weights, + noise_model="nb", + ) + + self.assertIsInstance(res, pd.DataFrame) + self.assertEqual(len(res), 10) + self.assertListEqual( + list(res.columns), ["log_fold_change", "p_value", "p_value_adj"] + ) + self.assertTrue(not res.isnull().values.any()) + + def test_importance_weights_wrong_shape(self): + from csde.model_poisson import PoissonIntercept as InterceptRegression + + x_gt, y_gt = self.adata_gt.X.astype(float), np.zeros( + len(self.adata_gt), dtype=int + ) + x_hat = x_gt.copy() + x_unl = self.adata_pred.X.astype(float) + y_hat = np.zeros(len(self.adata_gt), dtype=int) + y_unl = np.zeros(len(self.adata_pred), dtype=int) + + bad_weights = np.ones(len(self.adata_gt) + 5) + with self.assertRaises(ValueError): + InterceptRegression( + inputs_gt=(x_gt, y_gt), + inputs_hat=(x_hat, y_hat), + inputs_unl=(x_unl, y_unl), + importance_weights=bad_weights, + ) + if __name__ == "__main__": unittest.main()
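
Review note on `compute_importance_weights` in `src/csde/spatial_utils.py`: its docstring derives the weight for cells of interest from `n_int * w / (n_int * w + n_other) = target_proportion`. The sketch below checks that algebra on made-up counts. It is a minimal, standalone illustration, not the package's code: the helper name `interest_weight` and the counts (50 and 950) are assumptions chosen for the example. It only verifies the single-draw sampling probability; with `replace=False`, as in `subsample_cells`, the realized fraction can deviate from the target, especially when `n_cells` is large relative to the number of cells of interest.

```python
import numpy as np


def interest_weight(n_interest: int, n_other: int, target_proportion: float) -> float:
    """Hypothetical helper: solve n_int * w / (n_int * w + n_other) = p for w."""
    if not (0.0 < target_proportion < 1.0):
        raise ValueError("target_proportion must be in (0, 1)")
    return target_proportion * n_other / (n_interest * (1.0 - target_proportion))


# Illustrative counts (assumed): 50 cells of interest among 1000 cells.
n_interest, n_other = 50, 950
w = interest_weight(n_interest, n_other, target_proportion=0.5)

# Unnormalized weights: w for cells of interest, 1 for every other cell.
weights = np.concatenate([np.full(n_interest, w), np.ones(n_other)])

# Normalizing gives the probability of picking each cell in a single draw.
probs = weights / weights.sum()

# Total single-draw probability mass on the cells of interest.
expected_fraction = probs[:n_interest].sum()
print(w, expected_fraction)
```

With these counts the cells of interest make up 5% of the data, so each one has to be up-weighted substantially to reach a 50% share of the sampling mass. This is why the per-cell `sampling_weight` is kept in `metadata.csv` for downstream weighted estimation.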