Source code for scCS.multicomparison

"""
multicomparison.py — MultiScorer: multi-condition (3+) commitment score analysis for scCS.

Extends the pairwise PairScorer to handle 3 or more experimental conditions
(e.g., multiple drug treatments, time points, genotypes).

Key design principle: shared embedding
---------------------------------------
All conditions are embedded in a SINGLE shared star layout built on the pooled
data.  This is critical — if each condition had its own embedding, the arm
angles would differ and CS values would not be comparable across conditions.

Architecture
------------
MultiScorer
    Wraps SingleScorer.  Pools all conditions for embedding, then scores
    each condition separately using cell masks on the shared embedding.

Tier 1 — Core multi-condition API
    score_all_conditions()          : dict[condition -> CommitmentScoreResult]

Tier 2 — Omnibus + post-hoc statistical comparison
    compare_omnibus()               : Kruskal-Wallis / ANOVA per fate
    compare_posthoc()               : Dunn / Tukey / Conover pairwise post-hoc
    compute_pairwise_deltas()       : ΔCS for ALL condition pairs with bootstrap CI

Tier 3 — Advanced
    fit_mixed_model()               : linear mixed-effects model on per-cell
                                      fate affinity scores via statsmodels MixedLM
    fit_mixed_model_contrasts()     : LMM with custom condition contrasts
    trajectory_shift()              : KS test + Wasserstein distance on
                                      pseudotime distributions per fate arm
    plot_trajectory_shift()         : visualization of pseudotime distributions

Usage
-----
>>> mscorer = scCS.MultiScorer(
...     adata,
...     root='17',
...     branches=['homeostatic', 'activated'],
...     condition_obs_key='treatment',
...     obs_key='leiden',
... )
>>> mscorer.build_embedding(ordering_metric='pseudotime')
>>> mscorer.fit()
>>> results = mscorer.score_all_conditions()
>>> omnibus = mscorer.compare_omnibus(results)
>>> posthoc = mscorer.compare_posthoc(results, omnibus_results=omnibus)
>>> deltas = mscorer.compute_pairwise_deltas()
>>> shift = mscorer.trajectory_shift(results)
"""

from __future__ import annotations

import warnings
from itertools import combinations
from typing import Dict, List, Literal, Optional, Tuple, Union

import matplotlib.figure
import numpy as np
import pandas as pd

from .single import SingleScorer
from .scores import (
    CommitmentScoreResult,
    compute_cell_scores,
    compute_magnitudes,
    compute_angles,
    bin_angles,
    compute_sector_magnitudes,
    compute_pairwise_cs_matrix,
    centroid_sectors,
    equal_sectors,
)
from .plot import _fate_colors, _condition_colors, CONDITION_PALETTE


# ---------------------------------------------------------------------------
# MultiScorer
# ---------------------------------------------------------------------------

class MultiScorer:
    """RNA velocity commitment scorer for experiments with 3+ conditions.

    All conditions are embedded together: one SHARED star embedding is
    built on the pooled cells, and each condition is then scored on that
    same layout. Arm geometry is therefore identical across conditions and
    CS values can be compared directly.

    Statistical testing is tiered:

    - Tier 2: omnibus tests (Kruskal-Wallis / ANOVA), followed by post-hoc
      pairwise comparisons (Dunn / Tukey / Conover).
    - Tier 3: mixed-effects models with custom contrasts, and trajectory
      shift analysis.

    Parameters
    ----------
    adata : AnnData
        Full single-cell dataset containing all conditions.
    root : str
        Label of the progenitor/root cluster in ``adata.obs[obs_key]``.
        Converted with ``str()`` before being stored.
    branches : list of str
        Labels of the k terminal fate clusters.
    condition_obs_key : str
        Column in ``adata.obs`` with condition labels (e.g. 'treatment').
        Values are compared as strings; at least 3 unique values are
        required.
    obs_key : str
        Column in ``adata.obs`` with cluster labels. Default: 'leiden'.
    n_angle_bins : int
        Number of angular bins. Default: 36.
    sector_method : {'centroid', 'equal'}
        Sector definition strategy.
    copy : bool
        If True, work on ``adata.copy()`` instead of the caller's object.

    Raises
    ------
    ValueError
        If ``condition_obs_key`` is not a column of ``adata.obs``, or if it
        has fewer than 3 unique values (use SingleScorer for 1 condition,
        PairScorer for 2).

    Examples
    --------
    >>> mscorer = MultiScorer(
    ...     adata,
    ...     root='17',
    ...     branches=['homeostatic', 'activated'],
    ...     condition_obs_key='treatment',
    ...     obs_key='leiden',
    ... )
    >>> mscorer.build_embedding(ordering_metric='pseudotime')
    >>> mscorer.fit()
    >>> results = mscorer.score_all_conditions()
    >>> omnibus = mscorer.compare_omnibus(results)
    >>> posthoc = mscorer.compare_posthoc(results, omnibus_results=omnibus)
    """

    def __init__(
        self,
        adata,
        root: str,
        branches: List[str],
        condition_obs_key: str,
        obs_key: str = "leiden",
        n_angle_bins: int = 36,
        sector_method: Literal["centroid", "equal"] = "centroid",
        copy: bool = False,
    ):
        # Configuration is stored before validation (same order as always).
        self.adata = adata if not copy else adata.copy()
        self.root = str(root)
        self.branches = list(branches)
        self.condition_obs_key = condition_obs_key
        self.obs_key = obs_key
        self.n_angle_bins = n_angle_bins
        self.sector_method = sector_method

        obs = adata.obs
        if condition_obs_key not in obs:
            raise ValueError(
                f"condition_obs_key='{condition_obs_key}' not found in adata.obs. "
                f"Available columns: {list(adata.obs.columns)}"
            )

        # Labels are normalised to str and sorted so the condition order is
        # deterministic regardless of the column's dtype.
        unique_labels = obs[condition_obs_key].astype(str).unique()
        self.conditions = sorted(unique_labels.tolist())

        n_found = len(self.conditions)
        if n_found < 3:
            if n_found == 1:
                hint = "Use SingleScorer for single-condition analysis. "
            else:
                hint = "Use PairScorer for 2-condition analysis. "
            raise ValueError(
                f"MultiScorer requires at least 3 conditions, but "
                f"condition_obs_key='{condition_obs_key}' has {n_found} "
                f"unique value(s): {self.conditions}. " + hint
            )

        # The pooled SingleScorer is created by build_embedding().
        self._scorer: Optional[SingleScorer] = None
        self._fitted = False

        print(
            f"[scCS] MultiScorer initialized.\n"
            f" Conditions ({n_found}): {self.conditions}\n"
            f" Root: '{root}', "
            f"Branches: {branches}"
        )

    # ------------------------------------------------------------------
    # Step 1: Build shared embedding (delegates to SingleScorer)
    # ------------------------------------------------------------------
[docs] def build_embedding( self, ordering_metric: Union[str, np.ndarray] = "pseudotime", invert_ordering: bool = False, scale_ordering: bool = False, arm_scale: float = 10.0, jitter: float = 0.3, seed: int = 42, arm_norm: str = "global", verbose: bool = True, ) -> "MultiScorer": """Build the shared star embedding on pooled data from all conditions. The embedding is built on ALL cells (all conditions pooled), ensuring that arm geometry is identical across conditions. Parameters ---------- ordering_metric : str or np.ndarray See SingleScorer.build_embedding(). invert_ordering : bool scale_ordering : bool arm_scale : float jitter : float seed : int verbose : bool Returns ------- self """ if verbose: print( f"[scCS] Building SHARED embedding on pooled data " f"({self.adata.n_obs} cells, {len(self.conditions)} conditions)..." ) self._scorer = SingleScorer( self.adata, root=self.root, branches=self.branches, obs_key=self.obs_key, n_angle_bins=self.n_angle_bins, sector_method=self.sector_method, copy=False, ) self._scorer.build_embedding( ordering_metric=ordering_metric, invert_ordering=invert_ordering, scale_ordering=scale_ordering, arm_scale=arm_scale, jitter=jitter, seed=seed, arm_norm=arm_norm, verbose=verbose, ) return self
[docs] def refit_pseudotime( self, scale_01: bool = True, arm_scale: float = 10.0, jitter: float = 0.3, seed: int = 42, arm_norm: str = "global", verbose: bool = True, ) -> "MultiScorer": """Rebuild the shared embedding using subset-local pseudotime. See SingleScorer.refit_pseudotime(). """ self._check_embedding() self._scorer.refit_pseudotime( scale_01=scale_01, arm_scale=arm_scale, jitter=jitter, seed=seed, arm_norm=arm_norm, verbose=verbose, ) self._fitted = False return self
# ------------------------------------------------------------------ # Step 2: Fit (delegates to SingleScorer) # ------------------------------------------------------------------
[docs] def fit(self, verbose: bool = True) -> "MultiScorer": """Fit the shared FateMap and project velocity. Must be called after build_embedding(). Returns ------- self """ self._check_embedding() self._scorer.fit(verbose=verbose) self._fitted = True return self
# ------------------------------------------------------------------ # Tier 1: Score all conditions # ------------------------------------------------------------------
[docs] def score_all_conditions( self, cell_level: bool = True, k_nn: Optional[int] = None, n_bootstrap: int = 0, bootstrap_ci: float = 0.95, verbose: bool = True, ) -> Dict[str, CommitmentScoreResult]: """Compute commitment scores separately for each condition. Uses the shared embedding and FateMap. Each condition's cells are masked from the shared adata_sub, so arm geometry is identical. Parameters ---------- cell_level : bool Compute per-cell fate affinity scores. k_nn : int, optional NN-smoothed entropy neighbors. n_bootstrap : int Bootstrap replicates for CI. 0 = disabled. bootstrap_ci : float CI level for bootstrap. verbose : bool Returns ------- dict : condition_label -> CommitmentScoreResult """ self._check_fitted() results: Dict[str, CommitmentScoreResult] = {} for cond in self.conditions: mask = ( self._scorer.adata_sub.obs[self.condition_obs_key].astype(str) == cond ).values n_cond = mask.sum() if n_cond < 10: warnings.warn( f"Condition '{cond}' has only {n_cond} cells in the embedding. " "Skipping.", stacklevel=2, ) continue if verbose: print(f"\n[scCS] Scoring condition: '{cond}' ({n_cond} cells)...") results[cond] = self._scorer.score( cell_mask=mask, cell_level=cell_level, k_nn=k_nn, n_bootstrap=n_bootstrap, bootstrap_ci=bootstrap_ci, verbose=verbose, write_to_obs=False, ) return results
# ------------------------------------------------------------------ # Tier 2: Omnibus + post-hoc statistical comparison # ------------------------------------------------------------------ def _per_condition_cell_scores( self, results: Dict[str, "CommitmentScoreResult"], j: int, ) -> Dict[str, np.ndarray]: """Return cell_scores[:, j] sliced to each condition's cells. SingleScorer.score returns cell_scores sized to the FULL shared adata_sub embedding (so per-cell affinities reflect every cell's velocity vector). For per-condition statistical tests we want each group to contain only that condition's cells. Falls back to the raw column when the array is already condition-sized (length differs from adata_sub.n_obs). """ _adata_sub = self._scorer.adata_sub groups: Dict[str, np.ndarray] = {} for c, res in results.items(): cs = res.cell_scores if cs is None: groups[c] = np.empty(0) continue if cs.shape[0] == _adata_sub.n_obs: m = ( _adata_sub.obs[self.condition_obs_key].astype(str) == c ).values groups[c] = cs[m, j] else: groups[c] = cs[:, j] return groups
[docs] def compare_omnibus( self, results: Dict[str, CommitmentScoreResult], test: Literal["kruskal", "anova"] = "kruskal", pval_threshold: float = 0.05, verbose: bool = True, ) -> pd.DataFrame: """Omnibus test across all conditions per fate. For each fate arm, tests whether per-cell affinity scores differ across ALL conditions simultaneously. - 'kruskal': Kruskal-Wallis H test (non-parametric, recommended default) - 'anova': One-way ANOVA (parametric, assumes normality) Parameters ---------- results : dict Output of score_all_conditions() with cell_level=True. test : {'kruskal', 'anova'} Statistical test to use. Default: 'kruskal'. pval_threshold : float Significance threshold for flagging. Default 0.05. verbose : bool Returns ------- pd.DataFrame with columns: fate, test, statistic, pval, pval_adj, significant, n_conditions """ self._check_fitted() # Validate that cell_scores are available for cond, res in results.items(): if res.cell_scores is None: raise ValueError( f"cell_scores not available for condition '{cond}'. " "Re-run score_all_conditions(cell_level=True)." 
) fate_names = list(results.values())[0].fate_names cond_list = list(results.keys()) n_cond = len(cond_list) rows = [] for j, fate in enumerate(fate_names): cond_groups = self._per_condition_cell_scores(results, j) groups = [cond_groups[c] for c in cond_list] if test == "kruskal": from scipy.stats import kruskal try: stat, pval = kruskal(*groups) except Exception: stat, pval = np.nan, np.nan test_name = "kruskal-wallis" else: # anova from scipy.stats import f_oneway try: stat, pval = f_oneway(*groups) except Exception: stat, pval = np.nan, np.nan test_name = "one-way-anova" rows.append({ "fate": fate, "test": test_name, "statistic": float(stat), "pval": float(pval), "n_conditions": n_cond, }) df = pd.DataFrame(rows) # Multiple testing correction across fates (Bonferroni) df["pval_adj"] = np.minimum(df["pval"] * len(fate_names), 1.0) df["significant"] = df["pval_adj"] < pval_threshold if verbose: print(f"\n=== Omnibus test ({test}) across {n_cond} conditions ===") sig = df[df["significant"]] print(f" Significant fates: {len(sig)} / {len(df)}") if len(sig) > 0: print(sig[["fate", "test", "statistic", "pval", "pval_adj", "significant"]].to_string(index=False)) else: print(f" No significant differences at pval_adj < {pval_threshold}.") return df
[docs] def compare_posthoc( self, results: Dict[str, CommitmentScoreResult], omnibus_results: Optional[pd.DataFrame] = None, method: Literal["dunn", "tukey", "conover"] = "dunn", pval_correction: Literal["fdr", "bonferroni", "holm"] = "fdr", pval_threshold: float = 0.05, verbose: bool = True, ) -> pd.DataFrame: """Post-hoc pairwise comparisons across conditions per fate. Only meaningful after an omnibus test rejects H0. If omnibus_results is provided, post-hoc is only run for fates where omnibus p < threshold. Methods ------- - 'dunn': Dunn's test with rank-based comparisons (non-parametric, recommended with Kruskal-Wallis). Uses scikit-posthocs. - 'tukey': Tukey HSD (parametric, for balanced designs, with ANOVA). - 'conover': Conover-Iman test (more powerful than Dunn, non-parametric). Uses scikit-posthocs. Multiple testing correction applied across all pairwise comparisons within each fate arm. Parameters ---------- results : dict Output of score_all_conditions() with cell_level=True. omnibus_results : pd.DataFrame, optional Output of compare_omnibus(). If provided, post-hoc is only run for fates where omnibus pval_adj < pval_threshold. method : {'dunn', 'tukey', 'conover'} Post-hoc test method. Default: 'dunn'. pval_correction : {'fdr', 'bonferroni', 'holm'} Multiple testing correction method. Default: 'fdr'. pval_threshold : float Significance threshold. Default 0.05. verbose : bool Returns ------- pd.DataFrame with columns: fate, comparison, method, statistic, pval, pval_adj, significant, mean_A, mean_B, delta_mean """ self._check_fitted() # Validate that cell_scores are available for cond, res in results.items(): if res.cell_scores is None: raise ValueError( f"cell_scores not available for condition '{cond}'. " "Re-run score_all_conditions(cell_level=True)." 
) fate_names = list(results.values())[0].fate_names cond_list = list(results.keys()) # Filter to significant fates if omnibus provided if omnibus_results is not None: sig_fates = set( omnibus_results[omnibus_results["significant"]]["fate"].tolist() ) if not sig_fates: if verbose: print("[scCS] No significant fates from omnibus test. Skipping post-hoc.") return pd.DataFrame() fate_names = [f for f in fate_names if f in sig_fates] rows = [] for j, fate in enumerate(fate_names): # Per-condition cell_scores (sliced from shared embedding if needed) cond_groups = self._per_condition_cell_scores(results, j) # Build group labels and values for this fate group_labels = [] group_values = [] for cond in cond_list: vals = cond_groups[cond] group_labels.extend([cond] * len(vals)) group_values.extend(vals.tolist()) group_labels = np.array(group_labels) group_values = np.array(group_values) if method == "dunn": try: import scikit_posthocs as sp # Dunn's test returns a matrix of p-values dunn_p = sp.posthoc_dunn( [cond_groups[c] for c in cond_list], p_adjust=None, # we do our own correction ) # dunn_p is a DataFrame with integer indices for i_a, ca in enumerate(cond_list): for i_b, cb in enumerate(cond_list): if i_a >= i_b: continue pval = dunn_p.iloc[i_a, i_b] mean_a = cond_groups[ca].mean() mean_b = cond_groups[cb].mean() rows.append({ "fate": fate, "comparison": f"{ca} vs {cb}", "method": "dunn", "statistic": np.nan, # Dunn doesn't report per-pair stat "pval": float(pval), "mean_A": float(mean_a), "mean_B": float(mean_b), "delta_mean": float(mean_a - mean_b), }) except ImportError: raise ImportError( "scikit-posthocs is required for Dunn's test. 
" "Install it with: pip install scikit-posthocs" ) elif method == "conover": try: import scikit_posthocs as sp conover_p = sp.posthoc_conover( [cond_groups[c] for c in cond_list], p_adjust=None, ) for i_a, ca in enumerate(cond_list): for i_b, cb in enumerate(cond_list): if i_a >= i_b: continue pval = conover_p.iloc[i_a, i_b] mean_a = cond_groups[ca].mean() mean_b = cond_groups[cb].mean() rows.append({ "fate": fate, "comparison": f"{ca} vs {cb}", "method": "conover", "statistic": np.nan, "pval": float(pval), "mean_A": float(mean_a), "mean_B": float(mean_b), "delta_mean": float(mean_a - mean_b), }) except ImportError: raise ImportError( "scikit-posthocs is required for Conover-Iman test. " "Install it with: pip install scikit-posthocs" ) elif method == "tukey": from scipy.stats import tukey_hsd groups = [cond_groups[c] for c in cond_list] try: res = tukey_hsd(*groups) # res.pvalue is an (n_groups, n_groups) array for i_a, ca in enumerate(cond_list): for i_b, cb in enumerate(cond_list): if i_a >= i_b: continue pval = res.pvalue[i_a, i_b] mean_a = cond_groups[ca].mean() mean_b = cond_groups[cb].mean() rows.append({ "fate": fate, "comparison": f"{ca} vs {cb}", "method": "tukey-hsd", "statistic": float(res.statistic[i_a, i_b]), "pval": float(pval), "mean_A": float(mean_a), "mean_B": float(mean_b), "delta_mean": float(mean_a - mean_b), }) except Exception as e: warnings.warn( f"Tukey HSD failed for fate '{fate}': {e}", stacklevel=2, ) if not rows: return pd.DataFrame() df = pd.DataFrame(rows) # Multiple testing correction if pval_correction == "fdr": # Benjamini-Hochberg FDR pvals = df["pval"].values.copy() n = len(pvals) sorted_idx = np.argsort(pvals) sorted_pvals = pvals[sorted_idx] fdr = np.minimum(sorted_pvals * n / np.arange(1, n + 1), 1.0) # Enforce monotonicity (from right to left) for i in range(n - 2, -1, -1): fdr[i] = min(fdr[i], fdr[i + 1]) # Map back to original order pval_adj = np.empty(n) pval_adj[sorted_idx] = fdr df["pval_adj"] = pval_adj elif 
pval_correction == "bonferroni": df["pval_adj"] = np.minimum(df["pval"] * len(df), 1.0) elif pval_correction == "holm": # Holm-Bonferroni step-down pvals = df["pval"].values.copy() n = len(pvals) sorted_idx = np.argsort(pvals) sorted_pvals = pvals[sorted_idx] holm = np.minimum(sorted_pvals * np.arange(n, 0, -1), 1.0) # Enforce monotonicity for i in range(1, n): holm[i] = max(holm[i], holm[i - 1]) pval_adj = np.empty(n) pval_adj[sorted_idx] = holm df["pval_adj"] = pval_adj df["significant"] = df["pval_adj"] < pval_threshold if verbose: print(f"\n=== Post-hoc comparison ({method}, correction={pval_correction}) ===") sig = df[df["significant"]] print(f" Significant pairs: {len(sig)} / {len(df)}") if len(sig) > 0: print(sig[["fate", "comparison", "method", "pval", "pval_adj", "delta_mean"]].to_string(index=False)) else: print(f" No significant pairs at pval_adj < {pval_threshold}.") return df
[docs] def compute_pairwise_deltas( self, n_bootstrap: int = 500, ci: float = 0.95, seed: int = 42, verbose: bool = True, ) -> Dict[Tuple[str, str], Dict]: """Compute ΔCS for ALL condition pairs with bootstrap CI. Unlike PairScorer.compute_delta_CS() which takes two specific conditions, this computes delta for every pair in the condition set. Parameters ---------- n_bootstrap : int Number of bootstrap replicates. Default 500. ci : float Confidence interval level. Default 0.95. seed : int verbose : bool Returns ------- dict mapping (cond_a, cond_b) -> delta_result dict (same structure as PairScorer.compute_delta_CS() output). """ self._check_fitted() fate_map = self._scorer._fate_map vx = self._scorer._vx vy = self._scorer._vy # Sector definition (shared) if self.sector_method == "centroid": sectors, _ = centroid_sectors( fate_map.fate_centroids, fate_map.root_centroid, n_bins=self.n_angle_bins, ) else: sectors = equal_sectors(fate_map.k, n_bins=self.n_angle_bins) def _score_mask(mask, n_cells_per_fate): vx_m, vy_m = vx[mask], vy[mask] mag = compute_magnitudes(vx_m, vy_m) ang = compute_angles(vx_m, vy_m) _, M_bin = bin_angles(ang, mag, n_bins=self.n_angle_bins) M_sec = compute_sector_magnitudes(M_bin, sectors) return compute_pairwise_cs_matrix( M_sec, n_cells_per_fate=n_cells_per_fate, normalized=True ) all_deltas = {} rng = np.random.default_rng(seed) alpha = (1.0 - ci) / 2.0 k = fate_map.k for ca, cb in combinations(self.conditions, 2): mask_a = ( self._scorer.adata_sub.obs[self.condition_obs_key].astype(str) == ca ).values mask_b = ( self._scorer.adata_sub.obs[self.condition_obs_key].astype(str) == cb ).values n_cells_a = np.array([ int(mask_a[idx].sum()) for idx in fate_map.fate_cell_indices ], dtype=float) n_cells_b = np.array([ int(mask_b[idx].sum()) for idx in fate_map.fate_cell_indices ], dtype=float) # Point estimates nCS_A = _score_mask(mask_a, n_cells_a) nCS_B = _score_mask(mask_b, n_cells_b) delta = nCS_A - nCS_B # Bootstrap idx_a = np.where(mask_a)[0] 
idx_b = np.where(mask_b)[0] boot_deltas = np.zeros((n_bootstrap, k, k)) for b in range(n_bootstrap): boot_a = rng.choice(idx_a, size=len(idx_a), replace=True) boot_b = rng.choice(idx_b, size=len(idx_b), replace=True) mask_ba = np.zeros(len(vx), dtype=bool) mask_ba[boot_a] = True mask_bb = np.zeros(len(vx), dtype=bool) mask_bb[boot_b] = True nCS_ba = _score_mask(mask_ba, n_cells_a) nCS_bb = _score_mask(mask_bb, n_cells_b) boot_deltas[b] = nCS_ba - nCS_bb boot_deltas = np.where(np.isinf(boot_deltas), np.nan, boot_deltas) all_deltas[(ca, cb)] = { "delta_nCS": delta, "ci_low": np.nanpercentile(boot_deltas, alpha * 100, axis=0), "ci_high": np.nanpercentile(boot_deltas, (1 - alpha) * 100, axis=0), "nCS_A": nCS_A, "nCS_B": nCS_B, "fate_names": fate_map.fate_names, "condition_a": ca, "condition_b": cb, "n_bootstrap": n_bootstrap, "ci_level": ci, } if verbose: ci_pct = int(ci * 100) print(f"\n=== ΔCS: '{ca}' − '{cb}' ===") df_delta = pd.DataFrame( delta, index=fate_map.fate_names, columns=fate_map.fate_names ) print(" ΔnCS (point estimate):") print(df_delta.round(3).to_string()) return all_deltas
# ------------------------------------------------------------------ # Tier 3: Advanced — mixed-effects model # ------------------------------------------------------------------
[docs] def fit_mixed_model( self, results: Dict[str, CommitmentScoreResult], replicate_key: Optional[str] = None, ref_condition: Optional[str] = None, verbose: bool = True, ) -> pd.DataFrame: """Linear mixed-effects model on per-cell fate affinity scores. Models per-cell fate affinity as a function of condition (fixed effect) with optional sample/replicate as a random effect. Model (per fate j): affinity_ij ~ condition_i + (1 | sample_id_i) Uses statsmodels MixedLM. Parameters ---------- results : dict Output of score_all_conditions(cell_level=True). replicate_key : str, optional Column in adata_sub.obs with sample/replicate IDs. ref_condition : str, optional Reference condition for the fixed effect. verbose : bool Returns ------- pd.DataFrame with columns: fate, condition, coef, std_err, z_score, pval, pval_adj, ci_low, ci_high, significant """ try: import statsmodels.formula.api as smf except ImportError: raise ImportError( "statsmodels is required for mixed-effects modeling. " "pip install statsmodels" ) fate_names = list(results.values())[0].fate_names conditions = list(results.keys()) if ref_condition is None: ref_condition = sorted(conditions)[0] if ref_condition not in conditions: raise ValueError( f"ref_condition='{ref_condition}' not in conditions: " f"{conditions}" ) # Build long-form DataFrame with cell-level data # When res.cell_scores comes from the shared embedding it covers # every cell — slice to this condition's cells via condition_obs_key. 
_adata_sub = self._scorer.adata_sub rows = [] for cond, res in results.items(): if res.cell_scores is None: continue obs_names = ( np.asarray(res.cell_obs_names) if res.cell_obs_names is not None else np.array([]) ) cs = res.cell_scores if cs.shape[0] == _adata_sub.n_obs: m = (_adata_sub.obs[self.condition_obs_key].astype(str) == cond).values cs = cs[m] obs_names = np.asarray(_adata_sub.obs_names)[m] for i, obs_name in enumerate(obs_names): row = {"condition": cond, "obs_name": obs_name} for j, fate in enumerate(fate_names): row[f"affinity_{fate}"] = cs[i, j] if replicate_key is not None and replicate_key in _adata_sub.obs: row["sample_id"] = str( _adata_sub.obs.loc[obs_name, replicate_key] if obs_name in _adata_sub.obs_names else "unknown" ) else: row["sample_id"] = obs_name rows.append(row) df_long = pd.DataFrame(rows) # Set reference condition df_long["condition"] = pd.Categorical( df_long["condition"], categories=[ref_condition] + [c for c in conditions if c != ref_condition], ) all_rows = [] for fate in fate_names: col = f"affinity_{fate}" if col not in df_long.columns: continue try: if replicate_key is not None: model = smf.mixedlm( f"{col} ~ C(condition)", data=df_long, groups=df_long["sample_id"], ) fit = model.fit(reml=True, method="lbfgs") else: model = smf.ols(f"{col} ~ C(condition)", data=df_long) fit = model.fit() for param_name in fit.params.index: if "condition" not in param_name: continue cond_label = ( param_name .replace(f"C(condition)[T.", "") .replace("]", "") .strip() ) ci = fit.conf_int().loc[param_name] all_rows.append({ "fate": fate, "condition": cond_label, "reference": ref_condition, "coef": float(fit.params[param_name]), "std_err": float(fit.bse[param_name]), "z_score": float(fit.tvalues[param_name]), "pval": float(fit.pvalues[param_name]), "ci_low": float(ci.iloc[0]), "ci_high": float(ci.iloc[1]), }) except Exception as e: warnings.warn( f"Mixed model failed for fate '{fate}': {e}", stacklevel=2, ) if not all_rows: return pd.DataFrame() 
result_df = pd.DataFrame(all_rows) result_df["pval_adj"] = np.minimum(result_df["pval"] * len(result_df), 1.0) result_df["significant"] = result_df["pval_adj"] < 0.05 result_df = result_df.sort_values(["fate", "pval_adj"]).reset_index(drop=True) if verbose: print("\n=== Mixed-effects model results ===") sig = result_df[result_df["significant"]] print(f" Significant effects: {len(sig)} / {len(result_df)}") if len(sig) > 0: print( sig[["fate", "condition", "coef", "std_err", "pval_adj"]] .to_string(index=False) ) return result_df
[docs] def fit_mixed_model_contrasts( self, results: Dict[str, CommitmentScoreResult], contrasts: Optional[List[Tuple[str, str]]] = None, replicate_key: Optional[str] = None, ref_condition: Optional[str] = None, pval_threshold: float = 0.05, verbose: bool = True, ) -> pd.DataFrame: """Linear mixed-effects model with custom condition contrasts. Extends fit_mixed_model() to test specific condition comparisons within the LMM framework (more powerful than separate models). If contrasts is None, tests each condition vs ref_condition. If contrasts is provided, tests each specified pair, e.g.: [('drug_A', 'control'), ('drug_B', 'control'), ('drug_A', 'drug_B')] Uses statsmodels MixedLM with Wald tests on contrast coefficients. Parameters ---------- results : dict Output of score_all_conditions(cell_level=True). contrasts : list of (str, str), optional Pairs of conditions to compare. If None, all conditions vs ref_condition are tested. replicate_key : str, optional Column in adata_sub.obs with sample/replicate IDs. ref_condition : str, optional Reference condition. Required when contrasts is None. pval_threshold : float Significance threshold. Default 0.05. verbose : bool Returns ------- pd.DataFrame with columns: fate, contrast, coef, std_err, z_score, pval, pval_adj, significant """ try: import statsmodels.formula.api as smf except ImportError: raise ImportError( "statsmodels is required for mixed-effects modeling. 
" "pip install statsmodels" ) fate_names = list(results.values())[0].fate_names conditions = list(results.keys()) if ref_condition is None: ref_condition = sorted(conditions)[0] # Build contrasts list if contrasts is None: contrasts = [(c, ref_condition) for c in conditions if c != ref_condition] # Validate contrasts for ca, cb in contrasts: if ca not in conditions or cb not in conditions: raise ValueError( f"Contrast ('{ca}', '{cb}') contains condition not in " f"available conditions: {conditions}" ) # Build long-form DataFrame # When res.cell_scores comes from the shared embedding it covers # every cell — slice to this condition's cells via condition_obs_key. _adata_sub = self._scorer.adata_sub rows = [] for cond, res in results.items(): if res.cell_scores is None: continue obs_names = ( np.asarray(res.cell_obs_names) if res.cell_obs_names is not None else np.array([]) ) cs = res.cell_scores if cs.shape[0] == _adata_sub.n_obs: m = (_adata_sub.obs[self.condition_obs_key].astype(str) == cond).values cs = cs[m] obs_names = np.asarray(_adata_sub.obs_names)[m] for i, obs_name in enumerate(obs_names): row = {"condition": cond, "obs_name": obs_name} for j, fate in enumerate(fate_names): row[f"affinity_{fate}"] = cs[i, j] if replicate_key is not None and replicate_key in _adata_sub.obs: row["sample_id"] = str( _adata_sub.obs.loc[obs_name, replicate_key] if obs_name in _adata_sub.obs_names else "unknown" ) else: row["sample_id"] = obs_name rows.append(row) df_long = pd.DataFrame(rows) # Set reference condition df_long["condition"] = pd.Categorical( df_long["condition"], categories=[ref_condition] + [c for c in conditions if c != ref_condition], ) all_rows = [] for fate in fate_names: col = f"affinity_{fate}" if col not in df_long.columns: continue try: if replicate_key is not None: model = smf.mixedlm( f"{col} ~ C(condition)", data=df_long, groups=df_long["sample_id"], ) fit = model.fit(reml=True, method="lbfgs") else: model = smf.ols(f"{col} ~ C(condition)", 
data=df_long) fit = model.fit() # For each contrast, compute the difference in coefficients for ca, cb in contrasts: # Get coefficient for ca vs ref and cb vs ref param_a = f"C(condition)[T.{ca}]" param_b = f"C(condition)[T.{cb}]" # Handle ref_condition (intercept) if ca == ref_condition: coef_a = 0.0 se_a = 0.0 elif param_a in fit.params.index: coef_a = fit.params[param_a] se_a = fit.bse[param_a] else: continue if cb == ref_condition: coef_b = 0.0 se_b = 0.0 elif param_b in fit.params.index: coef_b = fit.params[param_b] se_b = fit.bse[param_b] else: continue # Contrast: coef_a - coef_b contrast_coef = coef_a - coef_b # Variance of difference: var(a) + var(b) - 2*cov(a,b) # Approximate: assume independence (conservative) contrast_se = np.sqrt(se_a**2 + se_b**2) z_score = contrast_coef / contrast_se if contrast_se > 0 else 0.0 from scipy.stats import norm pval = 2.0 * norm.sf(abs(z_score)) all_rows.append({ "fate": fate, "contrast": f"{ca} vs {cb}", "condition_a": ca, "condition_b": cb, "coef": float(contrast_coef), "std_err": float(contrast_se), "z_score": float(z_score), "pval": float(pval), }) except Exception as e: warnings.warn( f"Mixed model contrasts failed for fate '{fate}': {e}", stacklevel=2, ) if not all_rows: return pd.DataFrame() result_df = pd.DataFrame(all_rows) # FDR correction across all contrasts and fates pvals = result_df["pval"].values.copy() n = len(pvals) sorted_idx = np.argsort(pvals) sorted_pvals = pvals[sorted_idx] fdr = np.minimum(sorted_pvals * n / np.arange(1, n + 1), 1.0) for i in range(n - 2, -1, -1): fdr[i] = min(fdr[i], fdr[i + 1]) pval_adj = np.empty(n) pval_adj[sorted_idx] = fdr result_df["pval_adj"] = pval_adj result_df["significant"] = result_df["pval_adj"] < pval_threshold result_df = result_df.sort_values(["fate", "pval_adj"]).reset_index(drop=True) if verbose: print("\n=== Mixed-effects model contrasts ===") sig = result_df[result_df["significant"]] print(f" Significant contrasts: {len(sig)} / {len(result_df)}") if len(sig) > 
0: print( sig[["fate", "contrast", "coef", "std_err", "z_score", "pval_adj"]] .to_string(index=False) ) return result_df
# ------------------------------------------------------------------ # Tier 3: Advanced — trajectory shift analysis # ------------------------------------------------------------------
def trajectory_shift(
    self,
    results: Dict[str, CommitmentScoreResult],
    pseudotime_key: str = "sccs_pseudotime",
    n_bootstrap: int = 500,
    seed: int = 42,
    verbose: bool = True,
    pval_threshold: float = 0.05,
    min_cells: int = 5,
) -> pd.DataFrame:
    """Test whether pseudotime distributions differ across conditions per fate arm.

    For each fate arm and each pair of conditions, computes:

    - Kolmogorov-Smirnov (KS) statistic and p-value
    - Wasserstein distance (Earth Mover's Distance)
    - Percentile bootstrap 95% CI on the Wasserstein distance

    A cell belongs to a fate arm when ``adata_sub.obs[obs_key]`` equals the
    fate name (as a string). It belongs to a condition through
    ``adata_sub.obs[condition_obs_key]``. Cells whose pseudotime is NaN or
    infinite are excluded before testing. A single non-finite value would
    otherwise make the KS and Wasserstein results undefined.

    Parameters
    ----------
    results : dict
        Output of score_all_conditions(). Only its keys (condition labels)
        are used, to enumerate condition pairs.
    pseudotime_key : str
        Column in adata_sub.obs with pseudotime values. When it is missing,
        ``'velocity_pseudotime'`` is used instead, with a warning.
    n_bootstrap : int
        Bootstrap replicates for the Wasserstein CI. Default 500. When
        ``< 1``, no bootstrap is run and the CI columns are NaN.
    seed : int
        Seed for the bootstrap random generator.
    verbose : bool
        Print a summary of significant shifts.
    pval_threshold : float
        Significance level applied to the Bonferroni-adjusted KS p-value.
        Default 0.05.
    min_cells : int
        Minimum number of cells with finite pseudotime required in each
        condition of a pair. Pairs below this are skipped with a warning.
        Default 5.

    Returns
    -------
    pd.DataFrame
        One row per (fate, condition pair) with the columns:

        - fate, comparison, condition_a, condition_b
        - ks_stat, ks_pval
        - wasserstein, wasserstein_ci_low, wasserstein_ci_high
        - mean_pt_A, mean_pt_B, delta_mean_pt
        - n_cells_A, n_cells_B
        - ks_pval_adj (Bonferroni across all rows), significant

        Rows are sorted by fate, then ks_pval_adj. The frame is empty when
        no pair had enough cells.

    Raises
    ------
    RuntimeError
        If the scorer has not been fitted.
    ValueError
        If neither ``pseudotime_key`` nor ``'velocity_pseudotime'`` exists
        in ``adata_sub.obs``.
    """
    # Both functions have been in scipy.stats since scipy 1.0, so the old
    # energy_distance fallback could never trigger. It would also have
    # silently computed a different metric.
    from scipy.stats import ks_2samp, wasserstein_distance

    self._check_fitted()

    fate_names = self._scorer._fate_map.fate_names
    conditions = list(results.keys())
    obs = self._scorer.adata_sub.obs

    # Resolve pseudotime column
    if pseudotime_key not in obs:
        fallback = "velocity_pseudotime"
        if fallback in obs:
            warnings.warn(
                f"'{pseudotime_key}' not found. Using '{fallback}'. "
                "Run compute_local_pseudotime() for better results.",
                stacklevel=2,
            )
            pseudotime_key = fallback
        else:
            raise ValueError(
                f"Neither '{pseudotime_key}' nor 'velocity_pseudotime' found "
                "in adata_sub.obs. Run compute_local_pseudotime() first."
            )

    pt_all = np.asarray(obs[pseudotime_key], dtype=float)
    cluster_labels = obs[self.obs_key].astype(str).values
    cond_labels = obs[self.condition_obs_key].astype(str).values
    # Keep only cells with a usable pseudotime value.
    has_pt = np.isfinite(pt_all)

    rng = np.random.default_rng(seed)
    rows = []

    for fate in fate_names:
        fate_mask = (cluster_labels == str(fate)) & has_pt
        for ca, cb in combinations(conditions, 2):
            # Labels were stringified above, so compare against str keys.
            pt_a = pt_all[fate_mask & (cond_labels == str(ca))]
            pt_b = pt_all[fate_mask & (cond_labels == str(cb))]

            if len(pt_a) < min_cells or len(pt_b) < min_cells:
                warnings.warn(
                    f"Too few cells for fate '{fate}', "
                    f"'{ca}' (n={len(pt_a)}) or '{cb}' (n={len(pt_b)}). "
                    "Skipping.",
                    stacklevel=2,
                )
                continue

            # KS test
            ks_stat, ks_pval = ks_2samp(pt_a, pt_b)

            # Wasserstein distance
            w_obs = wasserstein_distance(pt_a, pt_b)

            # Percentile bootstrap CI on Wasserstein (resample each group
            # independently, with replacement).
            if n_bootstrap > 0:
                w_boot = np.empty(n_bootstrap)
                for b in range(n_bootstrap):
                    ba = rng.choice(pt_a, size=len(pt_a), replace=True)
                    bb = rng.choice(pt_b, size=len(pt_b), replace=True)
                    w_boot[b] = wasserstein_distance(ba, bb)
                ci_low, ci_high = np.percentile(w_boot, [2.5, 97.5])
            else:
                ci_low = ci_high = np.nan

            rows.append({
                "fate": fate,
                "comparison": f"{ca} vs {cb}",
                "condition_a": ca,
                "condition_b": cb,
                "ks_stat": float(ks_stat),
                "ks_pval": float(ks_pval),
                "wasserstein": float(w_obs),
                "wasserstein_ci_low": float(ci_low),
                "wasserstein_ci_high": float(ci_high),
                "mean_pt_A": float(pt_a.mean()),
                "mean_pt_B": float(pt_b.mean()),
                "delta_mean_pt": float(pt_a.mean() - pt_b.mean()),
                "n_cells_A": int(len(pt_a)),
                "n_cells_B": int(len(pt_b)),
            })

    if not rows:
        return pd.DataFrame()

    df = pd.DataFrame(rows)
    # Bonferroni correction across every (fate, condition pair) test.
    df["ks_pval_adj"] = np.minimum(df["ks_pval"] * len(df), 1.0)
    df["significant"] = df["ks_pval_adj"] < pval_threshold
    df = df.sort_values(["fate", "ks_pval_adj"]).reset_index(drop=True)

    if verbose:
        print("\n=== Trajectory shift analysis ===")
        sig = df[df["significant"]]
        print(f"  Significant shifts: {len(sig)} / {len(df)}")
        if len(sig) > 0:
            print(
                sig[[
                    "fate", "comparison", "ks_stat", "ks_pval_adj",
                    "wasserstein", "delta_mean_pt",
                ]].to_string(index=False)
            )

    return df
def plot_trajectory_shift(
    self,
    shift_df: pd.DataFrame,
    pseudotime_key: str = "sccs_pseudotime",
    color_map: Optional[Dict[str, str]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    title: Optional[str] = None,
    save_path: Optional[str] = None,
    pval_threshold: float = 0.05,
) -> matplotlib.figure.Figure:
    """Visualize pseudotime distributions per condition per fate arm.

    Produces a grid of KDE plots, with one row per fate arm and one column
    per pairwise comparison. Overlaid KDEs show how the pseudotime
    distributions shift between the two conditions of each comparison.
    Each column compares a different pair, so every panel in the top row
    gets its own legend.

    Parameters
    ----------
    shift_df : pd.DataFrame
        Output of trajectory_shift(). Must contain fate, comparison,
        condition_a, condition_b, ks_pval_adj and wasserstein.
    pseudotime_key : str
        Column in adata_sub.obs with pseudotime values. When it is missing,
        ``'velocity_pseudotime'`` is used instead, with a warning.
    color_map : dict, optional
        Condition -> color. Defaults to ``_condition_colors(self.conditions)``.
    figsize : tuple, optional
        Defaults to (4 * n_comparisons, 3 * n_fates).
    title : str, optional
        Figure suptitle.
    save_path : str, optional
        If given, the figure is saved there at 300 dpi.
    pval_threshold : float
        Adjusted KS p-value below which a panel is marked "*". Default 0.05.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    RuntimeError
        If the scorer has not been fitted.
    ValueError
        If ``shift_df`` is empty, or if no pseudotime column is available.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    self._check_fitted()

    if shift_df.empty:
        raise ValueError("shift_df is empty; nothing to plot.")

    obs = self._scorer.adata_sub.obs
    # Resolve the pseudotime column the same way trajectory_shift() does,
    # instead of failing later with a bare KeyError.
    if pseudotime_key not in obs:
        fallback = "velocity_pseudotime"
        if fallback not in obs:
            raise ValueError(
                f"Neither '{pseudotime_key}' nor 'velocity_pseudotime' found "
                "in adata_sub.obs. Run compute_local_pseudotime() first."
            )
        warnings.warn(
            f"'{pseudotime_key}' not found. Using '{fallback}'.",
            stacklevel=2,
        )
        pseudotime_key = fallback

    pt_all = np.asarray(obs[pseudotime_key], dtype=float)
    cluster_labels = obs[self.obs_key].astype(str).values
    cond_labels = obs[self.condition_obs_key].astype(str).values

    fate_names = shift_df["fate"].unique().tolist()
    comparisons = shift_df["comparison"].unique().tolist()

    if color_map is None:
        color_map = _condition_colors(self.conditions)

    n_fates = len(fate_names)
    n_comps = len(comparisons)
    if figsize is None:
        figsize = (n_comps * 4.0, n_fates * 3.0)

    sns.set_theme(style="ticks")
    fig, axes = plt.subplots(n_fates, n_comps, figsize=figsize, squeeze=False)

    for fi, fate in enumerate(fate_names):
        fate_mask = cluster_labels == str(fate)
        for ci, comp in enumerate(comparisons):
            ax = axes[fi][ci]
            row = shift_df[
                (shift_df["fate"] == fate) & (shift_df["comparison"] == comp)
            ]
            if row.empty:
                ax.set_visible(False)
                continue

            ca = row["condition_a"].values[0]
            cb = row["condition_b"].values[0]
            pt_a = pt_all[fate_mask & (cond_labels == str(ca))]
            pt_b = pt_all[fate_mask & (cond_labels == str(cb))]

            sns.kdeplot(
                pt_a, ax=ax, color=color_map.get(ca, "#0072B2"),
                fill=True, alpha=0.35, label=ca, linewidth=1.5,
            )
            sns.kdeplot(
                pt_b, ax=ax, color=color_map.get(cb, "#D55E00"),
                fill=True, alpha=0.35, label=cb, linewidth=1.5,
            )

            ks_p = row["ks_pval_adj"].values[0]
            w = row["wasserstein"].values[0]
            sig_str = "*" if ks_p < pval_threshold else "ns"
            ax.set_title(
                f"{fate}\nW={w:.3f} KS p_adj={ks_p:.3f} {sig_str}",
                fontsize=9,
            )
            ax.set_xlabel("Pseudotime", fontsize=8)
            ax.set_ylabel("Density" if ci == 0 else "", fontsize=8)
            # Each column compares a different pair of conditions, so every
            # column needs its own legend (the top row carries it).
            if fi == 0:
                ax.legend(fontsize=7, frameon=False)
            sns.despine(ax=ax)

    fig.suptitle(
        title or "Trajectory shift: pseudotime distributions by condition",
        fontsize=12,
        y=1.01,
    )
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# ------------------------------------------------------------------
# Summary visualizations (omnibus, post-hoc, pairwise ΔCS)
# ------------------------------------------------------------------
def plot_omnibus_summary(
    self,
    omnibus_df: pd.DataFrame,
    results: Dict[str, CommitmentScoreResult],
    posthoc_df: Optional[pd.DataFrame] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> matplotlib.figure.Figure:
    """Summary heatmap: fates × conditions showing omnibus significance.

    Left panel: heatmap of mean per-cell affinity per fate per condition,
    annotated with stars derived from the omnibus ``pval_adj``. A cell is
    left blank (NaN) when the condition has no cell-level scores, has no
    cells, or does not contain that fate.

    Right panel (only when ``posthoc_df`` is non-empty): a fate × comparison
    grid of stars for the significant post-hoc pairs, or a short note when
    none are significant.

    Parameters
    ----------
    omnibus_df : pd.DataFrame
        Output of compare_omnibus(). Must contain ``fate`` and ``pval_adj``.
    results : dict
        Output of score_all_conditions().
    posthoc_df : pd.DataFrame, optional
        Output of compare_posthoc(). Must contain ``fate``, ``comparison``,
        ``pval_adj`` and ``significant``.
    figsize : tuple, optional
    save_path : str, optional
    vmin, vmax : float, optional
        Color limits for the mean-affinity heatmap. When ``None`` (the
        default), each limit is derived from the finite values of the
        affinity matrix, so the colormap spans the actual data range. Set
        them explicitly to pin a fixed scale across figures.

    Returns
    -------
    fig : matplotlib Figure
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    fate_names = omnibus_df["fate"].tolist()
    conditions = list(results.keys())

    # Mean affinity matrix: fates × conditions. Missing data is NaN, not 0.
    # Seaborn then masks it, and it does not drag the auto-derived color
    # limits toward zero.
    adata_sub = self._scorer.adata_sub
    cond_labels = adata_sub.obs[self.condition_obs_key].astype(str).values
    mean_affinity = np.full((len(fate_names), len(conditions)), np.nan)
    for i, cond in enumerate(conditions):
        res = results[cond]
        cs = res.cell_scores
        if cs is None:
            continue
        # Arrays scored on the shared embedding have one row per adata_sub
        # cell; restrict them to this condition's cells.
        if cs.shape[0] == adata_sub.n_obs:
            cs = cs[cond_labels == str(cond)]
        if cs.shape[0] == 0:
            continue
        res_fates = list(res.fate_names)
        for j, fate in enumerate(fate_names):
            if fate in res_fates:
                mean_affinity[j, i] = cs[:, res_fates.index(fate)].mean()

    df_affinity = pd.DataFrame(mean_affinity, index=fate_names, columns=conditions)

    # Star annotations from omnibus p-values
    def _pval_stars(p):
        if p < 0.001:
            return "***"
        if p < 0.01:
            return "**"
        if p < 0.05:
            return "*"
        return ""

    star_annot = np.empty_like(mean_affinity, dtype=object)
    for j, fate in enumerate(fate_names):
        row = omnibus_df[omnibus_df["fate"] == fate]
        p = row["pval_adj"].values[0] if len(row) > 0 else 1.0
        # Omnibus significance is per fate, so the whole row gets the same stars.
        star_annot[j, :] = _pval_stars(p)

    n_panels = 2 if posthoc_df is not None and not posthoc_df.empty else 1
    if figsize is None:
        # Wide enough that the colorbar label (which includes the
        # [vmin, vmax] range) and the fate labels do not collide.
        figsize = (5.5 * n_panels + 1.0, max(3.5, len(fate_names) * 0.55 + 2.0))

    sns.set_theme(style="ticks")
    fig, axes = plt.subplots(1, n_panels, figsize=figsize, squeeze=False)

    # Left panel: mean affinity heatmap
    ax = axes[0][0]
    # Derive color limits from finite values so the colormap spans the
    # realized data range instead of a fixed [0, 1].
    finite_aff = mean_affinity[np.isfinite(mean_affinity)]
    if vmin is None:
        vmin_eff = float(finite_aff.min()) if finite_aff.size else 0.0
    else:
        vmin_eff = float(vmin)
    if vmax is None:
        vmax_eff = float(finite_aff.max()) if finite_aff.size else 1.0
    else:
        vmax_eff = float(vmax)
    if vmax_eff <= vmin_eff:
        vmax_eff = vmin_eff + 1e-9
    cbar_label = f"Mean affinity [{vmin_eff:.2f}, {vmax_eff:.2f}]"
    sns.heatmap(
        df_affinity,
        ax=ax,
        cmap="YlOrRd",
        vmin=vmin_eff,
        vmax=vmax_eff,
        annot=star_annot,
        fmt="",
        linewidths=0.5,
        cbar_kws={"label": cbar_label, "shrink": 0.8},
        annot_kws={"size": 12, "fontweight": "bold"},
    )
    ax.set_title("Mean affinity + omnibus significance", fontsize=10)
    ax.set_xlabel("Condition")
    ax.set_ylabel("Fate")
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

    # Right panel: posthoc significance grid (only exists when posthoc_df is
    # non-empty; see n_panels above).
    if n_panels == 2:
        ax = axes[0][1]
        sig = posthoc_df[posthoc_df["significant"]]
        if sig.empty:
            # Hide frame and ticks but keep the note readable.
            # set_visible(False) would hide the text too.
            ax.set_axis_off()
            ax.text(
                0.5, 0.5, "No significant\npost-hoc pairs",
                ha="center", va="center", fontsize=12, transform=ax.transAxes,
            )
        else:
            all_comps = posthoc_df["comparison"].unique().tolist()
            # dtype=object matters: an array filled from "" gets dtype '<U1',
            # which would silently truncate "**" and "***" to "*".
            sig_matrix = np.full(
                (len(fate_names), len(all_comps)), "", dtype=object
            )
            for j, fate in enumerate(fate_names):
                for i, comp in enumerate(all_comps):
                    row = sig[(sig["fate"] == fate) & (sig["comparison"] == comp)]
                    if not row.empty:
                        sig_matrix[j, i] = _pval_stars(row["pval_adj"].values[0])
            df_sig = pd.DataFrame(sig_matrix, index=fate_names, columns=all_comps)
            sns.heatmap(
                pd.DataFrame(0, index=fate_names, columns=all_comps),
                ax=ax,
                cmap="Greys",
                vmin=-1,
                vmax=1,
                annot=df_sig,
                fmt="",
                linewidths=0.5,
                cbar=False,
                annot_kws={"size": 12, "fontweight": "bold"},
            )
            ax.set_title("Post-hoc significance", fontsize=10)
            ax.set_xlabel("Comparison")
            ax.set_ylabel("Fate")
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
            ax.tick_params(axis="x", rotation=30)

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
def plot_posthoc_heatmap(
    self,
    posthoc_df: pd.DataFrame,
    fate: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
) -> matplotlib.figure.Figure:
    """Condition × condition heatmap of post-hoc results for a given fate.

    Conditions are parsed from the ``"A vs B"`` strings in ``comparison``
    and sorted alphabetically.

    - Lower triangle: significance stars ("***", "**", "*", "ns"), colored
      by -log10 of the adjusted p-value (clipped at 1e-10).
    - Upper triangle: ``delta_mean``, oriented so that cell (row, col) holds
      the value for the comparison "row vs col". It is negated when the
      comparison lists the two conditions in the opposite order.

    Both triangles share one diverging color scale, symmetric around 0.

    Parameters
    ----------
    posthoc_df : pd.DataFrame
        Output of compare_posthoc(). Must contain fate, comparison,
        pval_adj, delta_mean and significant.
    fate : str, optional
        Which fate to plot. If None, the first fate with significant results
        is used. When there are none, a placeholder figure is returned (and
        saved, if ``save_path`` is given).
    figsize : tuple, optional
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``fate`` has no rows in ``posthoc_df``.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    if fate is None:
        sig = posthoc_df[posthoc_df["significant"]]
        if sig.empty:
            print("[scCS] No significant post-hoc results to plot.")
            fig, ax = plt.subplots(figsize=(4, 3))
            ax.set_axis_off()
            ax.text(0.5, 0.5, "No significant results", ha="center",
                    va="center", transform=ax.transAxes)
            # Honor save_path on this path too, so callers expecting a file
            # always get one.
            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches="tight")
            return fig
        fate = sig["fate"].values[0]

    sub = posthoc_df[posthoc_df["fate"] == fate]
    if sub.empty:
        raise ValueError(f"No post-hoc results for fate '{fate}'.")

    # Get all conditions from comparisons
    conditions = set()
    for comp in sub["comparison"]:
        conditions.update(comp.split(" vs "))
    conditions = sorted(conditions)
    n = len(conditions)
    cond_idx = {c: i for i, c in enumerate(conditions)}

    # Symmetric p-value matrix; antisymmetric delta matrix.
    pval_matrix = np.full((n, n), np.nan)
    delta_matrix = np.full((n, n), np.nan)
    for _, row in sub.iterrows():
        parts = row["comparison"].split(" vs ")
        i, j = cond_idx[parts[0]], cond_idx[parts[1]]
        pval_matrix[i, j] = row["pval_adj"]
        pval_matrix[j, i] = row["pval_adj"]
        delta_matrix[i, j] = row["delta_mean"]
        delta_matrix[j, i] = -row["delta_mean"]

    def _pval_stars(p):
        if np.isnan(p):
            return ""
        if p < 0.001:
            return "***"
        if p < 0.01:
            return "**"
        if p < 0.05:
            return "*"
        return "ns"

    annot = np.empty((n, n), dtype=object)
    for i in range(n):
        for j in range(n):
            if i == j:
                annot[i, j] = ""
            elif i > j:  # lower triangle: significance stars
                annot[i, j] = _pval_stars(pval_matrix[i, j])
            else:  # upper triangle: delta mean
                d = delta_matrix[i, j]
                annot[i, j] = f"{d:+.3f}" if np.isfinite(d) else ""

    if figsize is None:
        figsize = (max(4, n * 1.2), max(3.5, n * 1.0))

    sns.set_theme(style="ticks")
    fig, ax = plt.subplots(figsize=figsize)

    # Color by -log10(p) on the lower triangle and by delta on the upper.
    display_matrix = np.full((n, n), np.nan)
    for i in range(n):
        for j in range(n):
            if i > j and not np.isnan(pval_matrix[i, j]):
                display_matrix[i, j] = -np.log10(np.clip(pval_matrix[i, j], 1e-10, None))
            elif i < j and np.isfinite(delta_matrix[i, j]):
                display_matrix[i, j] = delta_matrix[i, j]

    df_display = pd.DataFrame(display_matrix, index=conditions, columns=conditions)
    finite_vals = display_matrix[np.isfinite(display_matrix)]
    vmax = float(np.abs(finite_vals).max()) if finite_vals.size > 0 else 1.0
    if vmax == 0.0:
        # All-zero data would give vmin == vmax == 0, a degenerate norm.
        vmax = 1.0

    sns.heatmap(
        df_display,
        ax=ax,
        cmap="RdBu_r",
        center=0,
        vmin=-vmax,
        vmax=vmax,
        annot=annot,
        fmt="",
        linewidths=0.5,
        cbar_kws={"label": "-log10(p) / Δmean", "shrink": 0.8},
        annot_kws={"size": 9},
    )
    ax.set_title(f"Post-hoc comparison — '{fate}'", fontsize=11)
    ax.set_xlabel("Condition")
    ax.set_ylabel("Condition")
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
def plot_pairwise_delta_grid(
    self,
    delta_results: Dict[Tuple[str, str], Dict],
    figsize_per_panel: Tuple[float, float] = (4, 4),
    cmap: str = "RdBu_r",
    save_path: Optional[str] = None,
) -> matplotlib.figure.Figure:
    """Grid of ΔCS heatmaps for all condition pairs.

    Each panel shows ΔnCS = nCS_A − nCS_B for one condition pair. The bootstrap
    CI half-width, (ci_high − ci_low) / 2, is annotated as "±h" below each
    finite entry. Non-finite entries are shown as "—". Panels are laid out
    three per row, in the order of ``delta_results``. Each panel has its own
    color scale, symmetric around 0.

    Parameters
    ----------
    delta_results : dict
        Output of compute_pairwise_deltas(). Each value must provide
        ``delta_nCS``, ``ci_low``, ``ci_high`` (k × k arrays),
        ``fate_names``, ``condition_a`` and ``condition_b``.
    figsize_per_panel : tuple
    cmap : str
        Diverging colormap. Default 'RdBu_r'.
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``delta_results`` is empty.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    n_pairs = len(delta_results)
    if n_pairs == 0:
        raise ValueError("No delta results to plot.")

    ncols = min(n_pairs, 3)
    nrows = int(np.ceil(n_pairs / ncols))

    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(figsize_per_panel[0] * ncols, figsize_per_panel[1] * nrows),
        squeeze=False,
    )

    for idx, delta_result in enumerate(delta_results.values()):
        row, col = divmod(idx, ncols)
        ax = axes[row][col]

        delta = np.asarray(delta_result["delta_nCS"])
        ci_low = np.asarray(delta_result["ci_low"])
        ci_high = np.asarray(delta_result["ci_high"])
        fate_names = delta_result["fate_names"]
        cond_a = delta_result["condition_a"]
        cond_b = delta_result["condition_b"]
        k = len(fate_names)

        ci_half = (ci_high - ci_low) / 2.0
        annot = np.empty((k, k), dtype=object)
        for i in range(k):
            for j in range(k):
                d = delta[i, j]
                h = ci_half[i, j]
                if np.isfinite(d) and np.isfinite(h):
                    annot[i, j] = f"{d:+.2f}\n±{h:.2f}"
                elif np.isfinite(d):
                    annot[i, j] = f"{d:+.2f}"
                else:
                    annot[i, j] = "—"

        df = pd.DataFrame(delta, index=fate_names, columns=fate_names)
        finite_vals = delta[np.isfinite(delta)]
        vmax = float(np.abs(finite_vals).max()) if finite_vals.size > 0 else 1.0
        if vmax == 0.0:
            # All-zero deltas would give vmin == vmax == 0, a degenerate norm.
            vmax = 1.0

        sns.heatmap(
            df,
            ax=ax,
            cmap=cmap,
            center=0,
            vmin=-vmax,
            vmax=vmax,
            annot=annot,
            fmt="",
            linewidths=0.5,
            cbar_kws={"label": "ΔnCS", "shrink": 0.7},
            annot_kws={"size": 8},
        )
        ax.set_title(f"{cond_a}{cond_b}", fontsize=10)
        ax.set_xlabel("Reference fate (÷)", fontsize=9)
        ax.set_ylabel("Query fate (×)", fontsize=9)
        ax.tick_params(labelsize=8)

    # Hide unused axes
    for idx in range(n_pairs, nrows * ncols):
        row, col = divmod(idx, ncols)
        axes[row][col].set_visible(False)

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# ------------------------------------------------------------------ # Convenience: transfer labels to full adata # ------------------------------------------------------------------
def transfer_labels(
    self,
    results: Dict[str, CommitmentScoreResult],
    prefix: str = "cs_",
) -> None:
    """Write each condition's per-cell commitment scores into the full adata.

    For every (condition, result) pair, this delegates to the internal
    scorer's ``transfer_labels`` with target ``self.adata``. The column
    prefix is ``f"{prefix}{condition}_"``, so each condition's columns are
    kept apart.

    Parameters
    ----------
    results : dict
        Output of score_all_conditions(cell_level=True).
    prefix : str
        Column prefix. Default: 'cs_'.

    Raises
    ------
    RuntimeError
        If the MultiScorer has not been fitted.
    """
    self._check_fitted()
    inner_scorer = self._scorer
    for condition, cond_result in results.items():
        inner_scorer.transfer_labels(
            self.adata,
            cond_result,
            prefix=f"{prefix}{condition}_",
        )
# ------------------------------------------------------------------ # Plotting shortcuts (delegate to scorer) # ------------------------------------------------------------------
def plot_star(self, result: CommitmentScoreResult, **kwargs):
    """Radial star embedding plot for a single result.

    Checks that the MultiScorer is fitted, then forwards ``result`` and all
    keyword arguments to the internal scorer's ``plot_star`` and returns
    whatever it returns.
    """
    self._check_fitted()
    inner_scorer = self._scorer
    return inner_scorer.plot_star(result, **kwargs)
def plot_star_grid(
    self,
    results: Dict[str, CommitmentScoreResult],
    color_map: Optional[Dict[str, str]] = None,
    figsize_per_panel: Tuple[float, float] = (6, 6),
    save_path: Optional[str] = None,
) -> matplotlib.figure.Figure:
    """Side-by-side star embedding plots, one per condition.

    All panels are drawn on the same shared embedding (``adata_sub``); only
    the result changes per panel. Panels form a single row in the order of
    ``results`` and are titled with the condition name.

    Parameters
    ----------
    results : dict
        Output of score_all_conditions().
    color_map : dict, optional
        Passed through to ``plot_star_embedding``.
    figsize_per_panel : tuple
        (width, height) of each panel.
    save_path : str, optional
        If given, the figure is saved there at 300 dpi.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    RuntimeError
        If the MultiScorer has not been fitted.
    ValueError
        If ``results`` is empty. Without this check, plt.subplots(1, 0)
        would fail with an obscure error.
    """
    import matplotlib.pyplot as plt
    from .plot import plot_star_embedding

    self._check_fitted()

    conditions = list(results.keys())
    n = len(conditions)
    if n == 0:
        raise ValueError("No results to plot.")

    fig, axes = plt.subplots(
        1, n,
        figsize=(figsize_per_panel[0] * n, figsize_per_panel[1]),
        squeeze=False,
    )
    for i, cond in enumerate(conditions):
        plot_star_embedding(
            self._scorer.adata_sub,
            results[cond],
            color_by="fate",
            color_map=color_map,
            ax=axes[0][i],
            title=cond,
        )
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
def plot_rose_grid(
    self,
    results: Dict[str, CommitmentScoreResult],
    color_map: Optional[Dict[str, str]] = None,
    figsize_per_panel: Tuple[float, float] = (5, 5),
    title: Optional[str] = None,
    save_path: Optional[str] = None,
) -> matplotlib.figure.Figure:
    """Grid of polar rose plots — one per condition.

    Forwards all arguments unchanged to ``plot_rose_grid`` in the package's
    ``plot`` module and returns its figure. Unlike most methods here, it does
    not call ``_check_fitted()`` first.
    """
    from .plot import plot_rose_grid as draw_rose_grid

    options = dict(
        color_map=color_map,
        figsize_per_panel=figsize_per_panel,
        title=title,
        save_path=save_path,
    )
    return draw_rose_grid(results, **options)
def plot_affinity_distributions(
    self,
    results: Dict[str, CommitmentScoreResult],
    plot_type: Literal["violin", "box", "strip"] = "violin",
    color_map: Optional[Dict[str, str]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    title: Optional[str] = None,
    save_path: Optional[str] = None,
) -> matplotlib.figure.Figure:
    """Violin/box/strip plots of per-cell fate affinity scores by condition.

    There is one panel per fate, up to three per row. Each panel shows one
    distribution per condition, in the order of ``results``. Fate names
    come from the first result. A condition whose result has no
    ``cell_scores`` contributes no data but keeps its slot on the x-axis.
    Arrays scored on the shared embedding (one row per ``adata_sub`` cell)
    are restricted to that condition's cells.

    Parameters
    ----------
    results : dict
        Output of score_all_conditions().
    plot_type : {'violin', 'box', 'strip'}
        Any value other than 'violin' or 'box' draws a strip plot.
    color_map : dict, optional
        Condition -> color. Defaults to ``_condition_colors(conditions)``.
    figsize : tuple, optional
    title : str, optional
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``results`` is empty.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    if not results:
        raise ValueError("No results to plot.")

    fate_names = list(next(iter(results.values())).fate_names)
    k_fates = len(fate_names)
    conditions = list(results.keys())

    # Build the long-form table (condition, fate, affinity) with vectorized
    # numpy/pandas operations instead of one dict per cell per fate.
    adata_sub = self._scorer.adata_sub
    cond_labels = None  # computed lazily, only if a full-embedding array appears
    frames = []
    for cond, res in results.items():
        cs = res.cell_scores
        if cs is None:
            continue
        if cs.shape[0] == adata_sub.n_obs:
            if cond_labels is None:
                cond_labels = adata_sub.obs[self.condition_obs_key].astype(str).values
            cs = cs[cond_labels == str(cond)]
        scores = np.asarray(cs)[:, :k_fates]
        n_cells = scores.shape[0]
        # Fate-major order: all cells for fate 0, then all cells for fate 1, ...
        frames.append(pd.DataFrame({
            "condition": cond,
            "fate": np.repeat(fate_names, n_cells),
            "affinity": scores.T.ravel(),
        }))
    if frames:
        df = pd.concat(frames, ignore_index=True)
    else:
        df = pd.DataFrame(columns=["condition", "fate", "affinity"])

    if color_map is None:
        color_map = _condition_colors(conditions)
    palette = [color_map.get(c, "#888888") for c in conditions]

    ncols = min(k_fates, 3)
    nrows = int(np.ceil(k_fates / ncols))
    if figsize is None:
        figsize = (ncols * 4.0, nrows * 3.5)

    sns.set_theme(style="ticks")
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

    for j, fate in enumerate(fate_names):
        row, col = divmod(j, ncols)
        ax = axes[row][col]
        sub = df[df["fate"] == fate]

        if plot_type == "violin":
            sns.violinplot(data=sub, x="condition", y="affinity", palette=palette,
                           ax=ax, inner="box", order=conditions, cut=0)
        elif plot_type == "box":
            sns.boxplot(data=sub, x="condition", y="affinity", palette=palette,
                        ax=ax, order=conditions,
                        flierprops={"marker": ".", "markersize": 2, "alpha": 0.3})
        else:
            sns.stripplot(data=sub, x="condition", y="affinity", palette=palette,
                          ax=ax, order=conditions, size=2, alpha=0.4, jitter=True)

        ax.set_title(fate, fontsize=11, fontweight="bold")
        ax.set_xlabel("")
        ax.set_ylabel("Fate affinity" if col == 0 else "")
        ax.set_ylim(0, 1)
        ax.tick_params(axis="x", rotation=20)
        sns.despine(ax=ax)

    for j in range(k_fates, nrows * ncols):
        row, col = divmod(j, ncols)
        axes[row][col].set_visible(False)

    fig.suptitle(
        title or f"Per-cell fate affinity by condition ({plot_type})",
        fontsize=12,
        y=1.01,
    )
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
def plot_delta_cs_heatmap(self, delta_result: dict, **kwargs) -> matplotlib.figure.Figure:
    """Heatmap of ΔCS = nCS_A − nCS_B with CI annotation.

    Forwards ``delta_result`` and any keyword arguments to
    ``plot_delta_cs_heatmap`` in the package's ``plot`` module. ``delta_result``
    is presumably one value of compute_pairwise_deltas(); confirm against
    that function.
    """
    from .plot import plot_delta_cs_heatmap as draw_delta_heatmap

    figure = draw_delta_heatmap(delta_result, **kwargs)
    return figure
def plot_compare_conditions_bar(
    self,
    results: Dict[str, CommitmentScoreResult],
    **kwargs,
) -> matplotlib.figure.Figure:
    """Grouped bar chart of nCS per condition.

    Forwards ``results`` and any keyword arguments to
    ``plot_compare_conditions_bar`` in the package's ``plot`` module.
    """
    from .plot import plot_compare_conditions_bar as draw_condition_bars

    figure = draw_condition_bars(results, **kwargs)
    return figure
def plot_commitment_vector_radar(
    self,
    results: Dict[str, CommitmentScoreResult],
    **kwargs,
) -> matplotlib.figure.Figure:
    """Radar / spider chart of commitment vectors per condition.

    Forwards ``results`` and any keyword arguments to
    ``plot_commitment_vector_radar`` in the package's ``plot`` module.
    """
    from .plot import plot_commitment_vector_radar as draw_radar

    figure = draw_radar(results, **kwargs)
    return figure
# ------------------------------------------------------------------
# Representation
# ------------------------------------------------------------------

def __repr__(self) -> str:
    """Summarize root, branches, conditions and fit status.

    Status is 'fitted' when ``self._fitted`` is truthy, otherwise
    'uninitialised'.
    """
    status = "fitted" if self._fitted else "uninitialised"
    return (
        f"MultiScorer("
        f"root='{self.root}', "
        f"branches={self.branches}, "
        f"conditions={self.conditions}, "
        f"status='{status}')"
    )

# ------------------------------------------------------------------
# Properties
# ------------------------------------------------------------------

@property
def scorer(self) -> Optional[SingleScorer]:
    """The internal SingleScorer used for embedding and scoring."""
    return self._scorer
@property
def adata_sub(self):
    """The embedding subset held by the internal SingleScorer.

    Returns ``None`` while no internal scorer exists yet.
    """
    inner_scorer = self._scorer
    return None if inner_scorer is None else inner_scorer.adata_sub
@property
def is_fitted(self) -> bool:
    """Return the internal ``_fitted`` flag (set once fit() has run)."""
    fitted_flag = self._fitted
    return fitted_flag
def _check_embedding(self): if self._scorer is None or not self._scorer._embedding_built: raise RuntimeError( "Star embedding not built. Call build_embedding() first." ) def _check_fitted(self): if not self._fitted: raise RuntimeError( "MultiScorer is not fitted. Call fit() first." )