# Source code for scCS.drivers

"""
drivers.py — Driver gene identification for scCS fate arms.

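Three complementary strategies:

1. Velocity-based drivers
   For each fate arm, rank genes by delta velocity: the gene's mean scVelo
   velocity in arm cells minus its mean velocity in the bifurcation
   (progenitor) cells.  High positive delta = gene is being upregulated
   specifically along that fate.
   Requires the 'velocity' layer (from scVelo pipeline).

2. DEG-based drivers
   For each fate arm, run a Wilcoxon rank-sum test comparing arm cells vs
   the bifurcation (progenitor) cluster.  Returns logFC and adjusted p-value
   per gene, with a significance flag.

3. Velocity-fate correlation drivers
   For each fate arm, compute the Spearman correlation between each gene's
   velocity and the per-cell fate affinity score, with FDR-adjusted
   p-values.  Requires the 'velocity' layer and per-cell fate scores.

All functions operate on adata_sub (the subset returned by build_star_embedding),
which contains only bifurcation + terminal fate cells.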
"""

from __future__ import annotations

import warnings
from typing import Dict, List, Optional

import numpy as np
import pandas as pd


# ---------------------------------------------------------------------------
# 1. Velocity-based driver genes
# ---------------------------------------------------------------------------

def get_velocity_drivers(
    adata_sub,
    fate_names: List[str],
    obs_key: str,
    root: str,
    n_top_genes: int = 50,
) -> Dict[str, pd.DataFrame]:
    """Rank genes by fate-specific scVelo velocity in each fate arm's cells.

    For every fate, the mean velocity of each gene over that fate's cells is
    compared with the gene's mean velocity over the progenitor (``root``)
    cells.  Genes are ranked by the difference (``delta_velocity``), which
    down-weights genes that are already moving fast in the progenitor.

    Parameters
    ----------
    adata_sub : AnnData
        Subset containing only bifurcation + terminal fate cells.
        Must have the 'velocity' layer (from scVelo).
    fate_names : list of str
        Terminal fate cluster labels.
    obs_key : str
        Column in adata_sub.obs with cluster labels.
    root : str
        Label of the progenitor cluster.  Its per-gene mean velocity is the
        baseline subtracted to obtain ``delta_velocity``.  If no cell carries
        this label, a warning is emitted and a zero baseline is used (so
        ``delta_velocity`` equals ``mean_velocity``).
    n_top_genes : int
        Number of top driver genes to print per fate.

    Returns
    -------
    dict : fate_name -> DataFrame with columns
        [gene, mean_velocity, progenitor_velocity, delta_velocity, rank],
        sorted by delta_velocity descending (most fate-specific first).
        Genes whose fate mean velocity is NaN are dropped.  Fates with no
        cells in ``adata_sub`` are skipped with a warning.

    Raises
    ------
    ValueError
        If ``adata_sub`` has no 'velocity' layer.
    """
    if "velocity" not in adata_sub.layers:
        raise ValueError(
            "'velocity' layer not found in adata_sub. "
            "Run the scVelo pipeline first (scorer.compute_velocity() or "
            "scvelo.tl.velocity())."
        )

    import scipy.sparse as sp

    V_genes = adata_sub.layers["velocity"]
    if sp.issparse(V_genes):
        V_genes = V_genes.toarray()
    V_genes = np.asarray(V_genes, dtype=float)  # (n_cells, n_genes)
    genes = np.asarray(adata_sub.var_names)
    obs_labels = adata_sub.obs[obs_key].astype(str).values

    # Progenitor mean velocity is computed once and reused as the baseline
    # for every fate's delta.
    bif_mask = obs_labels == str(root)
    if bif_mask.any():
        with warnings.catch_warnings():
            # Genes that are NaN in every progenitor cell yield NaN silently.
            warnings.simplefilter("ignore")
            mean_vel_progenitor = np.nanmean(V_genes[bif_mask, :], axis=0)
    else:
        warnings.warn(
            f"No cells found for root '{root}' in adata_sub. "
            "Using a zero baseline, so delta_velocity equals mean_velocity.",
            stacklevel=2,
        )
        mean_vel_progenitor = np.zeros(V_genes.shape[1])

    results: Dict[str, pd.DataFrame] = {}
    for name in fate_names:
        mask = obs_labels == str(name)
        if not mask.any():
            warnings.warn(
                f"No cells found for fate '{name}' in adata_sub. Skipping.",
                stacklevel=2,
            )
            continue

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            mean_vel_fate = np.nanmean(V_genes[mask, :], axis=0)

        # Delta velocity: fate mean minus progenitor mean.  This removes genes
        # that are constitutively active in the progenitor, highlighting genes
        # specifically upregulated along this fate arm.
        delta_vel = mean_vel_fate - mean_vel_progenitor

        df = pd.DataFrame({
            "gene": genes,
            "mean_velocity": mean_vel_fate,
            "progenitor_velocity": mean_vel_progenitor,
            "delta_velocity": delta_vel,
        }).dropna(subset=["mean_velocity"])

        # Sort by delta_velocity (fate-specific upregulation).
        df = df.sort_values("delta_velocity", ascending=False).reset_index(drop=True)
        df["rank"] = df.index + 1
        results[name] = df

        print(f"\n── Velocity drivers: {name} (top {n_top_genes}, sorted by delta_velocity) ──")
        print(
            df.head(n_top_genes)[["rank", "gene", "delta_velocity", "mean_velocity"]]
            .to_string(index=False)
        )

    return results
# ---------------------------------------------------------------------------
# 2. DEG-based driver genes
# ---------------------------------------------------------------------------
def _pts_column(table, group: str, gene_order: List[str]) -> np.ndarray:
    """Return per-gene expressed fractions for ``group``, aligned to ``gene_order``.

    ``table`` is the object scanpy stores under ``uns[key]["pts"]`` or
    ``uns[key]["pts_rest"]`` (a DataFrame indexed by gene name, one column per
    group).  A mapping of group -> (gene -> value), or group -> array already in
    ``gene_order``, is also accepted.  Returns an all-NaN array when the table
    or the group is missing or cannot be read.
    """
    missing = np.full(len(gene_order), np.nan)
    if table is None:
        return missing
    try:
        if group not in table:
            return missing
        values = table[group]
        if isinstance(values, np.ndarray):
            arr = values.astype(float)
            return arr if arr.shape == (len(gene_order),) else missing
        # Label-based lookup (pandas Series / dict) keeps alignment with the
        # ranked gene order even though ``pts`` is indexed in var_names order.
        return np.array([float(values.get(g, np.nan)) for g in gene_order], dtype=float)
    except (AttributeError, KeyError, TypeError, ValueError):
        return missing


def get_deg_drivers(
    adata_sub,
    fate_names: List[str],
    obs_key: str,
    root: str,
    n_top_genes: int = 50,
    pval_threshold: float = 0.05,
    logfc_threshold: float = 0.25,
) -> Dict[str, pd.DataFrame]:
    """Find DEGs for each fate arm vs the bifurcation cluster (Wilcoxon).

    For each fate arm, compares arm cells against progenitor (bifurcation)
    cells using a Wilcoxon rank-sum test via scanpy's
    ``tl.rank_genes_groups`` (with ``pts=True``).

    Parameters
    ----------
    adata_sub : AnnData
        Subset containing only bifurcation + terminal fate cells.
    fate_names : list of str
        Terminal fate cluster labels.
    obs_key : str
        Column in adata_sub.obs with cluster labels.
    root : str
        Label of the progenitor cluster (reference group).
    n_top_genes : int
        Number of top significant DEGs to print per fate.
    pval_threshold : float
        Adjusted p-value threshold for significance.
    logfc_threshold : float
        Minimum absolute log fold-change for significance.

    Returns
    -------
    dict : fate_name -> DataFrame with columns:
        [gene, logfoldchange, pval, pval_adj, significant,
         pct_fate, pct_progenitor]
        Sorted by logfoldchange descending.  ``pct_*`` are the fractions of
        cells expressing each gene in the fate / progenitor group (NaN if
        scanpy did not provide them).  Fates with fewer than 5 cells, or
        for which the test fails, are omitted with a warning.  If the root
        cluster has fewer than 5 cells, an empty dict is returned.

    Raises
    ------
    ImportError
        If at least one fate can be tested but scanpy is not installed.
    """
    obs_labels = adata_sub.obs[obs_key].astype(str).values
    results: Dict[str, pd.DataFrame] = {}

    # The reference group is the same for every fate: check it once.
    bif_mask = obs_labels == str(root)
    n_bif = int(bif_mask.sum())
    if n_bif < 5:
        warnings.warn(
            f"Root cluster '{root}' has only {n_bif} cells. "
            "Skipping DEG analysis (need ≥5).",
            stacklevel=2,
        )
        return results

    testable = []
    for name in fate_names:
        n_fate = int((obs_labels == str(name)).sum())
        if n_fate < 5:
            warnings.warn(
                f"Fate '{name}' has only {n_fate} cells. "
                "Skipping DEG analysis (need ≥5).",
                stacklevel=2,
            )
            continue
        testable.append(name)
    if not testable:
        return results

    # Imported only once there is something to test (scanpy is heavy).
    try:
        import scanpy as sc
    except ImportError as err:
        raise ImportError("scanpy is required for DEG analysis. pip install scanpy") from err

    for name in testable:
        label = str(name)
        # Avoid a clash when a fate is itself called "progenitor".
        ref_label = "progenitor" if label != "progenitor" else "__progenitor__"
        fate_mask = obs_labels == label

        # Subset to fate + progenitor only for this pairwise comparison.
        adata_pair = adata_sub[fate_mask | bif_mask].copy()
        pair_labels = adata_pair.obs[obs_key].astype(str).values
        adata_pair.obs["_deg_group"] = pd.Categorical(
            np.where(pair_labels == label, label, ref_label),
            categories=[label, ref_label],
        )

        try:
            sc.tl.rank_genes_groups(
                adata_pair,
                groupby="_deg_group",
                groups=[label],
                reference=ref_label,
                method="wilcoxon",
                key_added="rank_genes",
                pts=True,
            )
        except Exception as e:  # best-effort: one failing arm must not abort the rest
            warnings.warn(
                f"rank_genes_groups failed for fate '{name}': {e}",
                stacklevel=2,
            )
            continue

        rg = adata_pair.uns["rank_genes"]
        gene_order = list(rg["names"][label])
        df = pd.DataFrame({
            "gene": rg["names"][label],
            "logfoldchange": rg["logfoldchanges"][label],
            "pval": rg["pvals"][label],
            "pval_adj": rg["pvals_adj"][label],
        })
        df["significant"] = (
            (df["pval_adj"] < pval_threshold)
            & (df["logfoldchange"].abs() > logfc_threshold)
        )

        # With an explicit reference, scanpy stores the reference group's
        # fraction as a column of "pts"; "pts_rest" only exists for
        # reference="rest", where pts_rest[label] is the non-fate fraction.
        df["pct_fate"] = _pts_column(rg.get("pts"), label, gene_order)
        pct_prog = _pts_column(rg.get("pts"), ref_label, gene_order)
        if np.isnan(pct_prog).all():
            pct_prog = _pts_column(rg.get("pts_rest"), label, gene_order)
        df["pct_progenitor"] = pct_prog

        df = df.sort_values("logfoldchange", ascending=False).reset_index(drop=True)
        results[name] = df

        n_sig = df["significant"].sum()
        n_up = ((df["logfoldchange"] > logfc_threshold) & df["significant"]).sum()
        n_dn = ((df["logfoldchange"] < -logfc_threshold) & df["significant"]).sum()
        print(f"\n── DEG drivers: {name} vs progenitor ──")
        print(f"   Significant: {n_sig} (up: {n_up}, down: {n_dn})")
        sig_df = df[df["significant"]].head(n_top_genes)
        if len(sig_df) > 0:
            cols = ["gene", "logfoldchange", "pval_adj"]
            if "pct_fate" in sig_df.columns and not sig_df["pct_fate"].isna().all():
                cols += ["pct_fate", "pct_progenitor"]
            print(sig_df[cols].to_string(index=False))
        else:
            print("   (no significant DEGs at current thresholds)")

    return results
# ---------------------------------------------------------------------------
# 3. Velocity-fate correlation drivers (CellRank-style)
# ---------------------------------------------------------------------------
def _bh_fdr(pvals: np.ndarray) -> np.ndarray:
    """Return Benjamini-Hochberg adjusted p-values (same formula as statsmodels ``fdr_bh``)."""
    p = np.asarray(pvals, dtype=float)
    n = p.size
    if n == 0:
        return p.copy()
    order = np.argsort(p)
    scaled = p[order] * n / np.arange(1, n + 1)
    # Enforce monotonicity: each adjusted value is the min over all larger ranks.
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    adjusted = np.empty(n, dtype=float)
    adjusted[order] = np.minimum(scaled, 1.0)
    return adjusted


def get_velocity_fate_drivers(
    adata_sub,
    cell_scores: np.ndarray,
    fate_names: List[str],
    obs_key: str,
    root: str,
    n_top_genes: int = 50,
    pval_threshold: float = 0.05,
    min_cells: int = 10,
) -> Dict[str, pd.DataFrame]:
    """Identify driver genes by correlating gene velocity with fate affinity.

    For each fate arm, computes the Spearman correlation between each gene's
    velocity (from the 'velocity' layer) and the cell's fate affinity score
    (from ``cell_scores[:, j]``).  Genes with high positive Spearman
    correlation are being upregulated specifically as cells commit to that
    fate — a stronger signal than mean velocity alone, because it filters out
    genes that are fast everywhere.

    Algorithm
    ---------
    1. For each fate j, keep cells whose affinity ``cell_scores[:, j]`` is not NaN.
    2. For each gene, drop cells whose velocity is NaN; genes with fewer than
       3 remaining cells or (near-)constant velocity get r = 0, p = 1.
    3. Compute the Spearman correlation between affinity and velocity.
    4. Compute Benjamini-Hochberg FDR-adjusted p-values over all genes.
    5. Return a DataFrame sorted by spearman_r descending.

    Parameters
    ----------
    adata_sub : AnnData
        Subset containing only bifurcation + terminal fate cells.
        Must have the 'velocity' layer (from scVelo).
    cell_scores : array-like, shape (n_cells, k)
        Per-cell fate affinity scores from CommitmentScoreResult.cell_scores.
    fate_names : list of str
        Terminal fate cluster labels (length k).
    obs_key : str
        Column in adata_sub.obs with cluster labels.
    root : str
        Label of the progenitor cluster; its mean velocity is the baseline for
        ``delta_velocity`` (zero, with a warning, if the label is absent).
    n_top_genes : int
        Number of top driver genes to print per fate.
    pval_threshold : float
        FDR-adjusted p-value threshold for significance.
    min_cells : int
        Minimum number of cells with a valid affinity required to compute
        correlations for a fate.

    Returns
    -------
    dict : fate_name -> DataFrame with columns:
        [gene, spearman_r, pval, pval_adj, mean_velocity, delta_velocity,
         significant, rank]
        Sorted by spearman_r descending.  ``mean_velocity`` is averaged over
        the fate's cluster cells, or over all valid-affinity cells when the
        cluster is absent from ``adata_sub``.

    Raises
    ------
    ValueError
        If the 'velocity' layer is missing, ``cell_scores`` is None or not
        2-D, or its shape does not match ``adata_sub`` / ``fate_names``.
    """
    if "velocity" not in adata_sub.layers:
        raise ValueError(
            "'velocity' layer not found in adata_sub. "
            "Run the scVelo pipeline first (scorer.compute_velocity() or "
            "scvelo.tl.velocity())."
        )
    if cell_scores is None:
        raise ValueError(
            "cell_scores is None. "
            "Run scorer.score(cell_level=True) first."
        )
    cell_scores = np.asarray(cell_scores, dtype=float)
    if cell_scores.ndim != 2:
        raise ValueError(
            f"cell_scores must be 2-D (n_cells, n_fates); got shape {cell_scores.shape}."
        )
    if cell_scores.shape[0] != adata_sub.n_obs:
        raise ValueError(
            f"cell_scores has {cell_scores.shape[0]} rows but "
            f"adata_sub has {adata_sub.n_obs} cells."
        )
    if cell_scores.shape[1] != len(fate_names):
        raise ValueError(
            f"cell_scores has {cell_scores.shape[1]} columns but "
            f"fate_names has {len(fate_names)} entries."
        )

    import scipy.sparse as sp
    from scipy.stats import spearmanr

    V_genes = adata_sub.layers["velocity"]
    if sp.issparse(V_genes):
        V_genes = V_genes.toarray()
    V_genes = np.asarray(V_genes, dtype=float)  # (n_cells, n_genes)
    genes = np.array(adata_sub.var_names)
    obs_labels = adata_sub.obs[obs_key].astype(str).values

    # Progenitor mean velocity for delta computation.
    bif_mask = obs_labels == str(root)
    if bif_mask.any():
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            mean_vel_progenitor = np.nanmean(V_genes[bif_mask, :], axis=0)
    else:
        warnings.warn(
            f"No cells found for root '{root}' in adata_sub. "
            "Using a zero baseline, so delta_velocity equals mean_velocity.",
            stacklevel=2,
        )
        mean_vel_progenitor = np.zeros(V_genes.shape[1])

    results: Dict[str, pd.DataFrame] = {}

    for j, name in enumerate(fate_names):
        affinity = cell_scores[:, j]  # (n_cells,)

        # Filter to cells with non-NaN affinity.
        valid = ~np.isnan(affinity)
        n_valid = int(valid.sum())
        if n_valid < min_cells:
            warnings.warn(
                f"Only {n_valid} valid cells for fate '{name}'. "
                f"Need ≥{min_cells}. Skipping.",
                stacklevel=2,
            )
            continue

        V_sub = V_genes[valid, :]  # (n_valid, n_genes)
        a_sub = affinity[valid]    # (n_valid,)
        n_genes = V_sub.shape[1]
        rho = np.zeros(n_genes)
        pvals = np.ones(n_genes)

        with warnings.catch_warnings():
            # Constant-affinity inputs make spearmanr warn and return NaN;
            # those are mapped to r = 0, p = 1 below.
            warnings.simplefilter("ignore")
            for g in range(n_genes):
                v_g = V_sub[:, g]
                observed = ~np.isnan(v_g)
                # Fewer than 3 points gives a meaningless r = ±1.
                if observed.sum() < 3:
                    continue
                v_obs = v_g[observed]
                # Skip genes with zero variance in velocity.
                if np.std(v_obs) < 1e-10:
                    continue
                r, p = spearmanr(a_sub[observed], v_obs)
                rho[g] = float(r) if np.isfinite(r) else 0.0
                pvals[g] = float(p) if np.isfinite(p) else 1.0

        pvals_adj = _bh_fdr(pvals)

        # Mean velocity in fate arm cells; fall back to all valid-affinity
        # cells when the fate cluster itself is not present.
        fate_mask = obs_labels == str(name)
        mean_source = fate_mask if fate_mask.any() else valid
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            mean_vel_fate = np.nanmean(V_genes[mean_source, :], axis=0)
        delta_vel = mean_vel_fate - mean_vel_progenitor

        df = pd.DataFrame({
            "gene": genes,
            "spearman_r": rho,
            "pval": pvals,
            "pval_adj": pvals_adj,
            "mean_velocity": mean_vel_fate,
            "delta_velocity": delta_vel,
        })
        df["significant"] = df["pval_adj"] < pval_threshold
        df = df.sort_values("spearman_r", ascending=False).reset_index(drop=True)
        df["rank"] = df.index + 1
        results[name] = df

        n_sig = df["significant"].sum()
        print(f"\n── Velocity-fate drivers: {name} (top {n_top_genes}, sorted by Spearman r) ──")
        print(f"   Significant (FDR < {pval_threshold}): {n_sig} / {len(df)}")
        top = df[df["significant"]].head(n_top_genes)
        if len(top) == 0:
            top = df.head(n_top_genes)
        print(
            top[["rank", "gene", "spearman_r", "pval_adj", "mean_velocity"]]
            .to_string(index=False)
        )

    return results