Source code for scCS.pairwise

"""
pairwise.py — PairScorer: pairwise condition comparison for scCS.

Extends the single-condition SingleScorer to handle exactly 2 experimental
conditions (e.g., treatment vs. control, mutant vs. wild-type).

Key design principle: shared embedding
---------------------------------------
All conditions are embedded in a SINGLE shared star layout built on the pooled
data.  This is critical — if each condition had its own embedding, the arm
angles would differ and CS values would not be comparable across conditions.

Architecture
------------
PairScorer
    Wraps SingleScorer.  Pools both conditions for embedding, then scores
    each condition separately using cell masks on the shared embedding.

Tier 1 — Core pairwise API
    score_all_conditions()          : dict[condition -> CommitmentScoreResult]

Tier 2 — Statistical comparison
    compare_conditions()            : permutation test on per-cell fate affinity
    compute_delta_CS()              : ΔCS = nCS_A − nCS_B with bootstrap CI
    plot_affinity_distributions()   : violin/box plots of per-cell affinities

Tier 3 — Advanced
    fit_mixed_model()               : linear mixed-effects model on per-cell
                                      fate affinity scores via statsmodels MixedLM
    trajectory_shift()              : KS test + Wasserstein distance on
                                      pseudotime distributions per fate arm
    plot_trajectory_shift()         : visualization of pseudotime distributions

Usage
-----
>>> pscorer = scCS.PairScorer(
...     adata,
...     root='17',
...     branches=['homeostatic', 'activated'],
...     condition_obs_key='treatment',
...     obs_key='leiden',
... )
>>> pscorer.build_embedding(ordering_metric='pseudotime')
>>> pscorer.fit()
>>> results = pscorer.score_all_conditions()
>>> delta = pscorer.compute_delta_CS('control', 'treated')
>>> stats = pscorer.compare_conditions(results)
>>> pscorer.plot_affinity_distributions(results)
>>> shift = pscorer.trajectory_shift(results)
>>> pscorer.plot_trajectory_shift(shift)
"""

from __future__ import annotations

import warnings
from typing import Dict, List, Literal, Optional, Tuple, Union

import matplotlib.figure
import numpy as np
import pandas as pd

from .single import SingleScorer
from .scores import (
    CommitmentScoreResult,
    bootstrap_cs,
    compute_cell_scores,
    compute_magnitudes,
    compute_angles,
    bin_angles,
    compute_sector_magnitudes,
    compute_pairwise_cs_matrix,
    centroid_sectors,
    equal_sectors,
    compute_commitment_vector,
    compute_population_entropy,
    compute_mean_cell_entropy,
    compute_per_fate_cell_entropy,
    compute_nn_cell_entropy,
)
from .plot import _fate_colors, _condition_colors, CONDITION_PALETTE


# ---------------------------------------------------------------------------
# PairScorer
# ---------------------------------------------------------------------------

class PairScorer:
    """Velocity-based commitment scorer for two-condition experiments.

    Both conditions are pooled into ONE shared star embedding and are then
    scored separately on it.  Because the arm geometry is common to both
    conditions, the resulting CS values can be compared directly.

    Parameters
    ----------
    adata : AnnData
        Dataset that contains the cells of both conditions.
    root : str
        Label of the progenitor/root cluster in ``adata.obs[obs_key]``.
        Stored as ``str(root)``.
    branches : list of str
        Labels of the terminal fate clusters.
    condition_obs_key : str
        Column of ``adata.obs`` holding the condition labels.  After casting
        to ``str`` it must contain exactly 2 unique values.
    obs_key : str
        Column of ``adata.obs`` holding cluster labels.  Default: 'leiden'.
    n_angle_bins : int
        Number of angular bins.  Default: 36.
    sector_method : {'centroid', 'equal'}
        Sector definition strategy.
    copy : bool
        If True, work on ``adata.copy()`` instead of the object passed in.

    Raises
    ------
    ValueError
        If ``condition_obs_key`` is not a column of ``adata.obs`` or does not
        have exactly 2 unique values.

    Examples
    --------
    >>> pscorer = PairScorer(
    ...     adata,
    ...     root='17',
    ...     branches=['homeostatic', 'activated'],
    ...     condition_obs_key='treatment',
    ...     obs_key='leiden',
    ... )
    >>> pscorer.build_embedding(ordering_metric='pseudotime')
    >>> pscorer.fit()
    >>> results = pscorer.score_all_conditions()
    """

    def __init__(
        self,
        adata,
        root: str,
        branches: List[str],
        condition_obs_key: str,
        obs_key: str = "leiden",
        n_angle_bins: int = 36,
        sector_method: Literal["centroid", "equal"] = "centroid",
        copy: bool = False,
    ):
        # Plain configuration, stored before any validation takes place.
        if copy:
            self.adata = adata.copy()
        else:
            self.adata = adata
        self.root = str(root)
        self.branches = list(branches)
        self.condition_obs_key = condition_obs_key
        self.obs_key = obs_key
        self.n_angle_bins = n_angle_bins
        self.sector_method = sector_method

        # The condition column has to exist ...
        obs_table = adata.obs
        if condition_obs_key not in obs_table:
            available = list(adata.obs.columns)
            raise ValueError(
                f"condition_obs_key='{condition_obs_key}' not found in adata.obs. "
                f"Available columns: {available}"
            )

        # ... and hold exactly two distinct labels (compared as strings).
        labels_as_str = obs_table[condition_obs_key].astype(str)
        self.conditions = sorted(labels_as_str.unique().tolist())
        n_found = len(self.conditions)
        if n_found != 2:
            if n_found == 1:
                hint = "Use SingleScorer for single-condition analysis. "
            else:
                hint = "Use MultiScorer for 3+ conditions. "
            raise ValueError(
                f"PairScorer requires exactly 2 conditions, but "
                f"condition_obs_key='{condition_obs_key}' has {n_found} "
                f"unique value(s): {self.conditions}. " + hint
            )

        # The pooled-data SingleScorer is only created by build_embedding().
        self._scorer: Optional[SingleScorer] = None
        self._fitted = False

        banner = (
            f"[scCS] PairScorer initialized.\n"
            f" Conditions ({n_found}): {self.conditions}\n"
            f" Root: '{root}', "
            f"Branches: {branches}"
        )
        print(banner)

    # ------------------------------------------------------------------
    # Step 1: Build shared embedding (delegates to SingleScorer)
    # ------------------------------------------------------------------
def build_embedding(
    self,
    ordering_metric: Union[str, np.ndarray] = "pseudotime",
    invert_ordering: bool = False,
    scale_ordering: bool = False,
    arm_scale: float = 10.0,
    jitter: float = 0.3,
    seed: int = 42,
    arm_norm: str = "global",
    verbose: bool = True,
) -> "PairScorer":
    """Create the shared star embedding from the pooled cells.

    A fresh ``SingleScorer`` is constructed on ``self.adata`` (all cells, both
    conditions) and asked to build its embedding, so both conditions end up
    on the same arm geometry.

    Parameters
    ----------
    ordering_metric : str or np.ndarray
    invert_ordering : bool
    scale_ordering : bool
    arm_scale : float
    jitter : float
    seed : int
    arm_norm : str
        All of the above are forwarded unchanged to
        ``SingleScorer.build_embedding()``; see that method for semantics.
    verbose : bool
        Print a progress line here and forward the flag.

    Returns
    -------
    self
    """
    if verbose:
        n_cells = self.adata.n_obs
        n_conditions = len(self.conditions)
        print(
            f"[scCS] Building SHARED embedding on pooled data "
            f"({n_cells} cells, {n_conditions} conditions)..."
        )

    scorer_config = dict(
        root=self.root,
        branches=self.branches,
        obs_key=self.obs_key,
        n_angle_bins=self.n_angle_bins,
        sector_method=self.sector_method,
        copy=False,
    )
    embedding_options = dict(
        ordering_metric=ordering_metric,
        invert_ordering=invert_ordering,
        scale_ordering=scale_ordering,
        arm_scale=arm_scale,
        jitter=jitter,
        seed=seed,
        arm_norm=arm_norm,
        verbose=verbose,
    )

    # Store the scorer before building so it is reachable even if the
    # embedding step raises (same ordering as before).
    self._scorer = SingleScorer(self.adata, **scorer_config)
    self._scorer.build_embedding(**embedding_options)
    return self
def refit_pseudotime(
    self,
    scale_01: bool = True,
    arm_scale: float = 10.0,
    jitter: float = 0.3,
    seed: int = 42,
    arm_norm: str = "global",
    verbose: bool = True,
) -> "PairScorer":
    """Rebuild the shared embedding via ``SingleScorer.refit_pseudotime()``.

    All arguments are forwarded unchanged to the internal scorer.  The
    fitted flag is cleared afterwards, so ``fit()`` has to be called again
    before scoring.

    Returns
    -------
    self

    Raises
    ------
    RuntimeError
        If no embedding has been built yet.
    """
    self._check_embedding()
    forwarded = {
        "scale_01": scale_01,
        "arm_scale": arm_scale,
        "jitter": jitter,
        "seed": seed,
        "arm_norm": arm_norm,
        "verbose": verbose,
    }
    self._scorer.refit_pseudotime(**forwarded)
    # The embedding changed, so any earlier fit is stale.
    self._fitted = False
    return self
# ------------------------------------------------------------------ # Step 2: Fit (delegates to SingleScorer) # ------------------------------------------------------------------
def fit(self, verbose: bool = True) -> "PairScorer":
    """Fit the internal scorer on the shared embedding.

    Delegates to ``SingleScorer.fit()`` and then marks this object as
    fitted.  ``build_embedding()`` must have been called first.

    Parameters
    ----------
    verbose : bool
        Forwarded to ``SingleScorer.fit()``.

    Returns
    -------
    self

    Raises
    ------
    RuntimeError
        If no embedding has been built yet.
    """
    self._check_embedding()
    shared_scorer = self._scorer
    shared_scorer.fit(verbose=verbose)
    self._fitted = True
    return self
# ------------------------------------------------------------------ # Tier 1: Score all conditions # ------------------------------------------------------------------
def score_all_conditions(
    self,
    cell_level: bool = True,
    k_nn: Optional[int] = None,
    n_bootstrap: int = 0,
    bootstrap_ci: float = 0.95,
    verbose: bool = True,
    min_cells: int = 10,
) -> Dict[str, CommitmentScoreResult]:
    """Compute commitment scores separately for each condition.

    Every condition is scored by the shared internal scorer with a boolean
    cell mask over ``adata_sub``, so arm geometry is identical across
    conditions.

    Parameters
    ----------
    cell_level : bool
        Compute per-cell fate affinity scores.
    k_nn : int, optional
        NN-smoothed entropy neighbors.
    n_bootstrap : int
        Bootstrap replicates for CI. 0 = disabled.
    bootstrap_ci : float
        CI level for bootstrap.
    verbose : bool
    min_cells : int
        Conditions with fewer cells than this in the embedding are skipped
        with a warning.  Default 10 (the previously hard-coded value).

    Returns
    -------
    dict : condition_label -> CommitmentScoreResult
        Skipped conditions are absent from the dict.

    Raises
    ------
    RuntimeError
        If ``fit()`` has not been called.
    """
    self._check_fitted()

    # Loop-invariant: string condition label of every cell in the embedding.
    cond_labels = (
        self._scorer.adata_sub.obs[self.condition_obs_key].astype(str).values
    )

    results: Dict[str, CommitmentScoreResult] = {}
    for cond in self.conditions:
        mask = cond_labels == cond
        n_cond = int(mask.sum())

        if n_cond < min_cells:
            warnings.warn(
                f"Condition '{cond}' has only {n_cond} cells in the embedding. "
                "Skipping.",
                stacklevel=2,
            )
            continue

        if verbose:
            print(f"\n[scCS] Scoring condition: '{cond}' ({n_cond} cells)...")

        # write_to_obs=False: per-condition scores are returned, not written
        # into the shared AnnData.
        results[cond] = self._scorer.score(
            cell_mask=mask,
            cell_level=cell_level,
            k_nn=k_nn,
            n_bootstrap=n_bootstrap,
            bootstrap_ci=bootstrap_ci,
            verbose=verbose,
            write_to_obs=False,
        )

    return results
# ------------------------------------------------------------------ # Tier 2: Statistical comparison # ------------------------------------------------------------------
def compute_delta_CS(
    self,
    condition_a: str,
    condition_b: str,
    n_bootstrap: int = 500,
    ci: float = 0.95,
    seed: int = 42,
    verbose: bool = True,
) -> Dict:
    """Compute ΔCS = nCS_A − nCS_B with bootstrap confidence intervals.

    For each pair of fates (i, j), computes the difference in normalized
    commitment score between condition A and condition B.  The CI comes
    from resampling cells with replacement within each condition.

    Parameters
    ----------
    condition_a, condition_b : str
        Condition labels (must be in self.conditions).
    n_bootstrap : int
        Number of bootstrap replicates. Default 500.  With a value <= 0 the
        CI arrays are filled with NaN.
    ci : float
        Confidence interval level, strictly between 0 and 1. Default 0.95.
    seed : int
        Seed for the bootstrap random generator.
    verbose : bool
        Print the point estimate and CI bounds as tables.

    Returns
    -------
    dict with keys:
        'delta_nCS'  : np.ndarray (k, k) — nCS_A − nCS_B
        'ci_low'     : np.ndarray (k, k) — lower CI bound on delta
        'ci_high'    : np.ndarray (k, k) — upper CI bound on delta
        'nCS_A'      : np.ndarray (k, k) — nCS for condition A
        'nCS_B'      : np.ndarray (k, k) — nCS for condition B
        'fate_names' : list of str
        'condition_a': str
        'condition_b': str
        'n_bootstrap': int
        'ci_level'   : float

    Raises
    ------
    RuntimeError
        If ``fit()`` has not been called.
    ValueError
        If a condition label is unknown or ``ci`` is outside (0, 1).
    """
    self._check_fitted()

    for cond in (condition_a, condition_b):
        if cond not in self.conditions:
            raise ValueError(
                f"Condition '{cond}' not found. "
                f"Available: {self.conditions}"
            )
    if not 0.0 < ci < 1.0:
        raise ValueError(f"ci must lie strictly between 0 and 1, got {ci}.")

    fate_map = self._scorer._fate_map
    vx = self._scorer._vx
    vy = self._scorer._vy

    # Boolean masks over adata_sub, one per condition.
    cond_labels = (
        self._scorer.adata_sub.obs[self.condition_obs_key].astype(str).values
    )
    mask_a = cond_labels == condition_a
    mask_b = cond_labels == condition_b

    # Sector definition (shared between conditions).
    if self.sector_method == "centroid":
        sectors, _ = centroid_sectors(
            fate_map.fate_centroids,
            fate_map.root_centroid,
            n_bins=self.n_angle_bins,
        )
    else:
        sectors = equal_sectors(fate_map.k, n_bins=self.n_angle_bins)

    n_cells_a = np.array(
        [int(mask_a[idx].sum()) for idx in fate_map.fate_cell_indices],
        dtype=float,
    )
    n_cells_b = np.array(
        [int(mask_b[idx].sum()) for idx in fate_map.fate_cell_indices],
        dtype=float,
    )

    def _score_selection(selection, n_cells_per_fate):
        """nCS matrix for the cells picked by a boolean mask or index array."""
        vx_s, vy_s = vx[selection], vy[selection]
        mag = compute_magnitudes(vx_s, vy_s)
        ang = compute_angles(vx_s, vy_s)
        _, m_bin = bin_angles(ang, mag, n_bins=self.n_angle_bins)
        m_sec = compute_sector_magnitudes(m_bin, sectors)
        return compute_pairwise_cs_matrix(
            m_sec, n_cells_per_fate=n_cells_per_fate, normalized=True
        )

    # Point estimates.
    nCS_A = _score_selection(mask_a, n_cells_a)
    nCS_B = _score_selection(mask_b, n_cells_b)
    delta = nCS_A - nCS_B

    # Bootstrap.
    rng = np.random.default_rng(seed)
    alpha = (1.0 - ci) / 2.0
    k = fate_map.k
    n_boot = max(int(n_bootstrap), 0)
    boot_deltas = np.zeros((n_boot, k, k))
    idx_a = np.flatnonzero(mask_a)
    idx_b = np.flatnonzero(mask_b)

    for b in range(n_boot):
        boot_a = rng.choice(idx_a, size=idx_a.size, replace=True)
        boot_b = rng.choice(idx_b, size=idx_b.size, replace=True)
        # Index with the resampled positions directly.  Turning them into a
        # boolean mask would collapse duplicate draws, i.e. score a ~63%
        # subsample instead of a bootstrap sample and distort the CI.
        # NOTE(review): per-fate cell counts stay at their observed values,
        # as before; confirm whether they should follow the resample.
        boot_deltas[b] = (
            _score_selection(boot_a, n_cells_a)
            - _score_selection(boot_b, n_cells_b)
        )

    if n_boot > 0:
        # Infinite ratios carry no information for a percentile; ignore them.
        boot_deltas = np.where(np.isinf(boot_deltas), np.nan, boot_deltas)
        ci_low = np.nanpercentile(boot_deltas, alpha * 100, axis=0)
        ci_high = np.nanpercentile(boot_deltas, (1 - alpha) * 100, axis=0)
    else:
        ci_low = np.full((k, k), np.nan)
        ci_high = np.full((k, k), np.nan)

    result = {
        "delta_nCS": delta,
        "ci_low": ci_low,
        "ci_high": ci_high,
        "nCS_A": nCS_A,
        "nCS_B": nCS_B,
        "fate_names": fate_map.fate_names,
        "condition_a": condition_a,
        "condition_b": condition_b,
        "n_bootstrap": n_bootstrap,
        "ci_level": ci,
    }

    if verbose:
        names = fate_map.fate_names

        def _table(matrix):
            return pd.DataFrame(matrix, index=names, columns=names).round(3).to_string()

        ci_pct = int(ci * 100)
        print(f"\n=== ΔCS: '{condition_a}' − '{condition_b}' ===")
        print(" ΔnCS (point estimate):")
        print(_table(delta))
        print(f"\n {ci_pct}% CI low:")
        print(_table(ci_low))
        print(f"\n {ci_pct}% CI high:")
        print(_table(ci_high))

    return result
def compare_conditions(
    self,
    results: Dict[str, CommitmentScoreResult],
    test: Literal["permutation", "kruskal"] = "permutation",
    n_permutations: int = 1000,
    pval_threshold: float = 0.05,
    seed: int = 42,
    verbose: bool = True,
) -> pd.DataFrame:
    """Statistical comparison of per-cell fate affinity scores across conditions.

    With ``test='permutation'`` the pooled affinities of a fate are shuffled
    ``n_permutations`` times; the two-sided p-value is
    ``(b + 1) / (n_permutations + 1)`` where ``b`` counts permuted mean
    differences at least as extreme as the observed one.  The ``+ 1`` keeps
    the p-value from ever being exactly zero.  With ``test='kruskal'`` a
    Kruskal-Wallis test is run per fate.  P-values are Bonferroni-adjusted
    across fates.

    Parameters
    ----------
    results : dict
        Output of score_all_conditions() with cell_level=True.  Must contain
        exactly two conditions.
    test : {'permutation', 'kruskal'}
        Statistical test to use. Default: 'permutation'.
    n_permutations : int
        Number of permutations for the permutation test. Default 1000.
    pval_threshold : float
        Significance threshold on the adjusted p-value. Default 0.05.
    seed : int
        Seed for the permutation random generator.
    verbose : bool

    Returns
    -------
    pd.DataFrame with columns:
        fate, comparison, test, statistic, pval, mean_A, mean_B,
        pval_adj, significant

    Raises
    ------
    RuntimeError
        If ``fit()`` has not been called.
    ValueError
        If ``test`` is unknown, a result lacks ``cell_scores``, or ``results``
        does not hold exactly two conditions.
    """
    self._check_fitted()

    if test not in ("permutation", "kruskal"):
        raise ValueError(
            f"Unknown test '{test}'. Choose 'permutation' or 'kruskal'."
        )

    # Validate that cell_scores are available.
    for cond, res in results.items():
        if res.cell_scores is None:
            raise ValueError(
                f"cell_scores not available for condition '{cond}'. "
                "Re-run score_all_conditions(cell_level=True)."
            )

    # score_all_conditions() drops conditions with too few cells, so the
    # dict can legitimately arrive with fewer than two entries.
    if len(results) != 2:
        raise ValueError(
            f"compare_conditions() needs results for exactly 2 conditions, "
            f"got {len(results)}: {list(results.keys())}."
        )

    (cond_a, res_a), (cond_b, res_b) = results.items()
    fate_names = res_a.fate_names
    scores_a = res_a.cell_scores
    scores_b = res_b.cell_scores
    comparison = f"{cond_a} vs {cond_b}"

    rows = []
    if test == "permutation":
        rng = np.random.default_rng(seed)
        for j, fate in enumerate(fate_names):
            a_vals = scores_a[:, j]
            b_vals = scores_b[:, j]
            obs_diff = np.mean(a_vals) - np.mean(b_vals)

            # Null distribution from shuffled condition labels.
            pooled = np.concatenate([a_vals, b_vals])
            n_a = len(a_vals)
            null_diffs = np.zeros(n_permutations)
            for p in range(n_permutations):
                perm = rng.permutation(pooled)
                null_diffs[p] = perm[:n_a].mean() - perm[n_a:].mean()

            n_extreme = int(np.sum(np.abs(null_diffs) >= np.abs(obs_diff)))
            pval = (n_extreme + 1) / (n_permutations + 1)

            rows.append({
                "fate": fate,
                "comparison": comparison,
                "test": "permutation",
                "statistic": float(obs_diff),
                "pval": float(pval),
                "mean_A": float(a_vals.mean()),
                "mean_B": float(b_vals.mean()),
            })
    else:  # kruskal
        from scipy.stats import kruskal

        for j, fate in enumerate(fate_names):
            a_vals = scores_a[:, j]
            b_vals = scores_b[:, j]
            try:
                stat_kw, pval_kw = kruskal(a_vals, b_vals)
            except ValueError:
                # e.g. all values identical: report as undetermined.
                stat_kw, pval_kw = np.nan, np.nan

            rows.append({
                "fate": fate,
                "comparison": comparison,
                "test": "kruskal-wallis",
                "statistic": float(stat_kw),
                "pval": float(pval_kw),
                "mean_A": float(a_vals.mean()),
                "mean_B": float(b_vals.mean()),
            })

    # Explicit columns keep the frame well-formed even without fates.
    df = pd.DataFrame(
        rows,
        columns=[
            "fate", "comparison", "test", "statistic",
            "pval", "mean_A", "mean_B",
        ],
    )
    # Bonferroni correction across fates.
    df["pval_adj"] = np.minimum(df["pval"] * len(fate_names), 1.0)
    df["significant"] = df["pval_adj"] < pval_threshold

    if verbose:
        print("\n=== Condition comparison ===")
        sig = df[df["significant"]]
        print(f" Test: {test} | Significant results: {len(sig)} / {len(df)}")
        if len(sig) > 0:
            print(
                sig[["fate", "comparison", "test", "statistic", "pval_adj"]]
                .to_string(index=False)
            )
        else:
            print(f" No significant differences at pval_adj < {pval_threshold}.")

    return df
def plot_affinity_distributions(
    self,
    results: Dict[str, CommitmentScoreResult],
    plot_type: Literal["violin", "box", "strip"] = "violin",
    color_map: Optional[Dict[str, str]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    title: Optional[str] = None,
    save_path: Optional[str] = None,
) -> matplotlib.figure.Figure:
    """Violin/box/strip plots of per-cell fate affinity scores by condition.

    One panel per fate (at most 3 panels per row), each showing the
    distribution of per-cell affinity scores split by condition.

    Parameters
    ----------
    results : dict
        Output of score_all_conditions(cell_level=True).  Conditions whose
        ``cell_scores`` is None are left out of the plot.
    plot_type : {'violin', 'box', 'strip'}
        Any value other than 'violin' or 'box' produces a strip plot.
    color_map : dict, optional
        condition_label -> hex color.
    figsize : tuple, optional
    title : str, optional
    save_path : str, optional
        If given, the figure is saved there at 300 dpi.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If no condition in ``results`` carries per-cell scores.
    """
    fate_names = list(results.values())[0].fate_names
    k_fates = len(fate_names)
    conditions = list(results.keys())

    # Long-form table, built one column slice at a time rather than one
    # dict per cell.
    frames = [
        pd.DataFrame({
            "condition": cond,
            "fate": fate,
            "affinity": np.asarray(res.cell_scores)[:, j],
        })
        for cond, res in results.items()
        if res.cell_scores is not None
        for j, fate in enumerate(fate_names)
    ]
    if not frames:
        raise ValueError(
            "No per-cell scores to plot. "
            "Re-run score_all_conditions(cell_level=True)."
        )
    df = pd.concat(frames, ignore_index=True)

    # Plotting stack is imported only once there is something to draw.
    import matplotlib.pyplot as plt
    import seaborn as sns

    if color_map is None:
        color_map = _condition_colors(conditions)
    palette = [color_map.get(c, "#888888") for c in conditions]

    ncols = min(k_fates, 3)
    nrows = int(np.ceil(k_fates / ncols))
    if figsize is None:
        figsize = (ncols * 4.0, nrows * 3.5)

    sns.set_theme(style="ticks")
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

    # NOTE(review): recent seaborn releases warn when `palette` is passed
    # without `hue`; left unchanged to stay compatible with older versions.
    for j, fate in enumerate(fate_names):
        row, col = divmod(j, ncols)
        ax = axes[row][col]
        sub = df[df["fate"] == fate]

        if plot_type == "violin":
            sns.violinplot(
                data=sub, x="condition", y="affinity",
                palette=palette, ax=ax, inner="box",
                order=conditions, cut=0,
            )
        elif plot_type == "box":
            sns.boxplot(
                data=sub, x="condition", y="affinity",
                palette=palette, ax=ax, order=conditions,
                flierprops={"marker": ".", "markersize": 2, "alpha": 0.3},
            )
        else:  # strip
            sns.stripplot(
                data=sub, x="condition", y="affinity",
                palette=palette, ax=ax, order=conditions,
                size=2, alpha=0.4, jitter=True,
            )

        ax.set_title(fate, fontsize=11, fontweight="bold")
        ax.set_xlabel("")
        ax.set_ylabel("Fate affinity" if col == 0 else "")
        ax.set_ylim(0, 1)
        ax.tick_params(axis="x", rotation=20)
        sns.despine(ax=ax)

    # Hide the unused cells of the grid.
    for j in range(k_fates, nrows * ncols):
        row, col = divmod(j, ncols)
        axes[row][col].set_visible(False)

    fig.suptitle(
        title or f"Per-cell fate affinity by condition ({plot_type})",
        fontsize=12, y=1.01,
    )
    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig
# ------------------------------------------------------------------ # Tier 3: Advanced — mixed-effects model # ------------------------------------------------------------------
def fit_mixed_model(
    self,
    results: Dict[str, CommitmentScoreResult],
    replicate_key: Optional[str] = None,
    ref_condition: Optional[str] = None,
    verbose: bool = True,
) -> pd.DataFrame:
    """Linear (mixed-effects) model on per-cell fate affinity scores.

    Per fate, affinity is modelled as a function of condition (fixed
    effect).  With ``replicate_key`` a random intercept per replicate is
    added via statsmodels ``mixedlm``; without it an ordinary least squares
    model is fitted.

    Model (per fate j):
        affinity_ij ~ condition_i + (1 | sample_id_i)

    Parameters
    ----------
    results : dict
        Output of score_all_conditions(cell_level=True).  Results without
        ``cell_scores`` or ``cell_obs_names`` are ignored.
    replicate_key : str, optional
        Column in adata_sub.obs with sample/replicate IDs.
    ref_condition : str, optional
        Reference condition for the fixed effect.  Defaults to the
        alphabetically first condition in ``results``.
    verbose : bool

    Returns
    -------
    pd.DataFrame with columns:
        fate, condition, reference, coef, std_err, z_score, pval,
        ci_low, ci_high, pval_adj, significant
        Empty if no model could be fitted.

    Raises
    ------
    ValueError
        If ``results`` is empty, ``ref_condition`` is unknown, or
        ``replicate_key`` is not a column of ``adata_sub.obs``.
    ImportError
        If statsmodels is not installed.
    """
    if not results:
        raise ValueError("results is empty; nothing to model.")

    fate_names = list(results.values())[0].fate_names
    conditions = list(results.keys())

    if ref_condition is None:
        ref_condition = sorted(conditions)[0]
    if ref_condition not in conditions:
        raise ValueError(
            f"ref_condition='{ref_condition}' not in conditions: "
            f"{conditions}"
        )

    # A missing replicate column used to fall back silently to one group
    # per cell, which makes the random effect meaningless.
    replicate_lookup = None
    if replicate_key is not None:
        obs = self._scorer.adata_sub.obs
        if replicate_key not in obs:
            raise ValueError(
                f"replicate_key='{replicate_key}' not found in adata_sub.obs."
            )
        replicate_lookup = dict(zip(obs.index, obs[replicate_key].astype(str)))

    try:
        import statsmodels.formula.api as smf
    except ImportError as err:
        raise ImportError(
            "statsmodels is required for mixed-effects modeling. "
            "pip install statsmodels"
        ) from err

    # Long-form table: one row per cell, one positional column per fate.
    # Fate labels stay out of the column names so they never reach the
    # model formula.
    score_cols = [f"fate_{j}" for j in range(len(fate_names))]
    frames = []
    for cond, res in results.items():
        if res.cell_scores is None or res.cell_obs_names is None:
            continue
        obs_names = list(res.cell_obs_names)
        scores = np.asarray(res.cell_scores)[: len(obs_names)]
        frame = pd.DataFrame(scores, columns=score_cols)
        frame.insert(0, "obs_name", obs_names)
        frame.insert(0, "condition", cond)
        if replicate_lookup is not None:
            frame["sample_id"] = [
                replicate_lookup.get(name, "unknown") for name in obs_names
            ]
        else:
            frame["sample_id"] = obs_names
        frames.append(frame)

    if not frames:
        return pd.DataFrame()
    df_long = pd.concat(frames, ignore_index=True)

    # Reference condition first, so it becomes the baseline level.
    df_long["condition"] = pd.Categorical(
        df_long["condition"],
        categories=[ref_condition] + [c for c in conditions if c != ref_condition],
    )

    formula = "affinity ~ C(condition)"
    prefix = "C(condition)[T."
    all_rows = []
    for col, fate in zip(score_cols, fate_names):
        # A fixed response name keeps the formula valid for any fate label
        # (digits, spaces, dashes, ...).
        data = df_long[["condition", "sample_id"]].assign(affinity=df_long[col])
        try:
            if replicate_key is not None:
                model = smf.mixedlm(formula, data=data, groups=data["sample_id"])
                fit = model.fit(reml=True, method="lbfgs")
            else:
                fit = smf.ols(formula, data=data).fit()

            conf = fit.conf_int()
            for param_name in fit.params.index:
                if not (param_name.startswith(prefix) and param_name.endswith("]")):
                    continue
                interval = conf.loc[param_name]
                all_rows.append({
                    "fate": fate,
                    "condition": param_name[len(prefix):-1],
                    "reference": ref_condition,
                    "coef": float(fit.params[param_name]),
                    "std_err": float(fit.bse[param_name]),
                    "z_score": float(fit.tvalues[param_name]),
                    "pval": float(fit.pvalues[param_name]),
                    "ci_low": float(interval.iloc[0]),
                    "ci_high": float(interval.iloc[1]),
                })
        except Exception as e:
            # Deliberately best-effort: one failing fate must not abort the
            # others, so the failure is reported as a warning.
            warnings.warn(
                f"Mixed model failed for fate '{fate}': {e}",
                stacklevel=2,
            )

    if not all_rows:
        return pd.DataFrame()

    result_df = pd.DataFrame(all_rows)
    # Bonferroni correction across all reported effects.
    result_df["pval_adj"] = np.minimum(result_df["pval"] * len(result_df), 1.0)
    result_df["significant"] = result_df["pval_adj"] < 0.05
    result_df = result_df.sort_values(["fate", "pval_adj"]).reset_index(drop=True)

    if verbose:
        print("\n=== Mixed-effects model results ===")
        sig = result_df[result_df["significant"]]
        print(f" Significant effects: {len(sig)} / {len(result_df)}")
        if len(sig) > 0:
            print(
                sig[["fate", "condition", "coef", "std_err", "pval_adj"]]
                .to_string(index=False)
            )

    return result_df
# ------------------------------------------------------------------ # Tier 3: Advanced — trajectory shift analysis # ------------------------------------------------------------------
def trajectory_shift(
    self,
    results: Dict[str, CommitmentScoreResult],
    pseudotime_key: str = "sccs_pseudotime",
    n_bootstrap: int = 500,
    seed: int = 42,
    verbose: bool = True,
) -> pd.DataFrame:
    """Test whether pseudotime distributions differ across conditions per fate arm.

    For each fate arm (cells whose cluster label equals the fate name),
    computes between the first two conditions of ``results``:
    - Kolmogorov-Smirnov (KS) statistic and p-value
    - Wasserstein distance (Earth Mover's Distance)
    - Bootstrap 95% percentile CI on the Wasserstein distance

    Non-finite pseudotime values are dropped before testing.  KS p-values
    are Bonferroni-adjusted across the reported rows.

    Parameters
    ----------
    results : dict
        Output of score_all_conditions().  Only its keys (condition labels)
        are used; at least two are required.
    pseudotime_key : str
        Column in adata_sub.obs with pseudotime values.  Falls back to
        'velocity_pseudotime' (with a warning) if absent.
    n_bootstrap : int
        Bootstrap replicates for Wasserstein CI. Default 500.  With a value
        <= 0 the CI columns are NaN.
    seed : int
    verbose : bool

    Returns
    -------
    pd.DataFrame with columns:
        fate, comparison, condition_a, condition_b, ks_stat, ks_pval,
        wasserstein, wasserstein_ci_low, wasserstein_ci_high,
        mean_pt_A, mean_pt_B, delta_mean_pt, n_cells_A, n_cells_B,
        ks_pval_adj, significant
        Empty if no fate arm has at least 5 usable cells per condition.

    Raises
    ------
    RuntimeError
        If ``fit()`` has not been called.
    ValueError
        If fewer than two conditions are given or no pseudotime column
        can be found.
    """
    # Both functions ship with every SciPy >= 1.0.  Substituting
    # energy_distance would report a different metric under the
    # "wasserstein" name, so no fallback is used.
    from scipy.stats import ks_2samp, wasserstein_distance

    self._check_fitted()

    conditions = list(results.keys())
    if len(conditions) < 2:
        raise ValueError(
            f"trajectory_shift() needs results for 2 conditions, "
            f"got {len(conditions)}: {conditions}."
        )
    ca, cb = conditions[0], conditions[1]

    fate_names = self._scorer._fate_map.fate_names
    obs = self._scorer.adata_sub.obs

    # Resolve pseudotime column.
    if pseudotime_key not in obs:
        fallback = "velocity_pseudotime"
        if fallback in obs:
            warnings.warn(
                f"'{pseudotime_key}' not found. Using '{fallback}'. "
                "Run compute_local_pseudotime() for better results.",
                stacklevel=2,
            )
            pseudotime_key = fallback
        else:
            raise ValueError(
                f"Neither '{pseudotime_key}' nor 'velocity_pseudotime' found "
                "in adata_sub.obs. Run compute_local_pseudotime() first."
            )

    pt_all = np.array(obs[pseudotime_key], dtype=float)
    usable = np.isfinite(pt_all)
    cluster_labels = obs[self.obs_key].astype(str).values
    in_a = cond_labels_a = obs[self.condition_obs_key].astype(str).values == ca
    in_b = obs[self.condition_obs_key].astype(str).values == cb

    rng = np.random.default_rng(seed)
    n_boot = max(int(n_bootstrap), 0)
    rows = []

    for fate in fate_names:
        fate_mask = (cluster_labels == str(fate)) & usable
        pt_a = pt_all[fate_mask & in_a]
        pt_b = pt_all[fate_mask & in_b]

        if len(pt_a) < 5 or len(pt_b) < 5:
            warnings.warn(
                f"Too few cells for fate '{fate}', "
                f"'{ca}' (n={len(pt_a)}) or '{cb}' (n={len(pt_b)}). "
                "Skipping.",
                stacklevel=2,
            )
            continue

        ks_stat, ks_pval = ks_2samp(pt_a, pt_b)
        w_obs = wasserstein_distance(pt_a, pt_b)

        # Percentile bootstrap CI on the Wasserstein distance.
        if n_boot > 0:
            w_boot = np.zeros(n_boot)
            for b in range(n_boot):
                ba = rng.choice(pt_a, size=len(pt_a), replace=True)
                bb = rng.choice(pt_b, size=len(pt_b), replace=True)
                w_boot[b] = wasserstein_distance(ba, bb)
            w_low = float(np.percentile(w_boot, 2.5))
            w_high = float(np.percentile(w_boot, 97.5))
        else:
            w_low = w_high = float("nan")

        mean_a = float(pt_a.mean())
        mean_b = float(pt_b.mean())
        rows.append({
            "fate": fate,
            "comparison": f"{ca} vs {cb}",
            "condition_a": ca,
            "condition_b": cb,
            "ks_stat": float(ks_stat),
            "ks_pval": float(ks_pval),
            "wasserstein": float(w_obs),
            "wasserstein_ci_low": w_low,
            "wasserstein_ci_high": w_high,
            "mean_pt_A": mean_a,
            "mean_pt_B": mean_b,
            "delta_mean_pt": mean_a - mean_b,
            "n_cells_A": int(len(pt_a)),
            "n_cells_B": int(len(pt_b)),
        })

    if not rows:
        return pd.DataFrame()

    df = pd.DataFrame(rows)
    # Bonferroni correction across the tested fate arms.
    df["ks_pval_adj"] = np.minimum(df["ks_pval"] * len(df), 1.0)
    df["significant"] = df["ks_pval_adj"] < 0.05
    df = df.sort_values(["fate", "ks_pval_adj"]).reset_index(drop=True)

    if verbose:
        print("\n=== Trajectory shift analysis ===")
        sig = df[df["significant"]]
        print(f" Significant shifts: {len(sig)} / {len(df)}")
        if len(sig) > 0:
            print(
                sig[[
                    "fate", "comparison", "ks_stat", "ks_pval_adj",
                    "wasserstein", "delta_mean_pt",
                ]].to_string(index=False)
            )

    return df
def plot_trajectory_shift(
    self,
    shift_df: pd.DataFrame,
    pseudotime_key: str = "sccs_pseudotime",
    color_map: Optional[Dict[str, str]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    title: Optional[str] = None,
    save_path: Optional[str] = None,
) -> matplotlib.figure.Figure:
    """Visualize pseudotime distributions per condition per fate arm.

    Draws one row per fate in ``shift_df``.  Each row overlays the KDEs of
    the two conditions and shows the Wasserstein distance and adjusted KS
    p-value in its title.

    Parameters
    ----------
    shift_df : pd.DataFrame
        Output of trajectory_shift().  Must be non-empty.
    pseudotime_key : str
        Column in adata_sub.obs with pseudotime values; falls back to
        'velocity_pseudotime' if absent.
    color_map : dict, optional
        condition_label -> color.
    figsize : tuple, optional
    title : str, optional
    save_path : str, optional
        If given, the figure is saved there at 300 dpi.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    RuntimeError
        If ``fit()`` has not been called.
    ValueError
        If ``shift_df`` is empty or no pseudotime column can be found.
    """
    self._check_fitted()

    # trajectory_shift() returns a column-less frame when every fate arm was
    # skipped; a 0-row subplot grid cannot be created from that.
    if shift_df is None or shift_df.empty or "fate" not in shift_df.columns:
        raise ValueError(
            "shift_df is empty: trajectory_shift() reported no fate arm, "
            "so there is nothing to plot."
        )

    import matplotlib.pyplot as plt
    import seaborn as sns

    obs = self._scorer.adata_sub.obs
    if pseudotime_key not in obs:
        fallback = "velocity_pseudotime"
        if fallback not in obs:
            raise ValueError(
                f"Neither '{pseudotime_key}' nor 'velocity_pseudotime' found "
                "in adata_sub.obs. Run compute_local_pseudotime() first."
            )
        pseudotime_key = fallback

    pt_all = np.array(obs[pseudotime_key], dtype=float)
    usable = np.isfinite(pt_all)  # same filter as trajectory_shift()
    cluster_labels = obs[self.obs_key].astype(str).values
    cond_labels = obs[self.condition_obs_key].astype(str).values

    fate_names = shift_df["fate"].unique().tolist()

    if color_map is None:
        color_map = _condition_colors(self.conditions)

    n_fates = len(fate_names)
    if figsize is None:
        figsize = (5.0, n_fates * 3.0)

    sns.set_theme(style="ticks")
    fig, axes = plt.subplots(n_fates, 1, figsize=figsize, squeeze=False)

    for fi, fate in enumerate(fate_names):
        ax = axes[fi][0]
        fate_mask = (cluster_labels == str(fate)) & usable
        row = shift_df[shift_df["fate"] == fate]

        if row.empty:
            ax.set_visible(False)
            continue

        ca = row["condition_a"].values[0]
        cb = row["condition_b"].values[0]

        for cond, default_color in ((ca, "#0072B2"), (cb, "#D55E00")):
            sns.kdeplot(
                pt_all[fate_mask & (cond_labels == cond)],
                ax=ax,
                color=color_map.get(cond, default_color),
                fill=True, alpha=0.35, label=cond, linewidth=1.5,
            )

        ks_p = row["ks_pval_adj"].values[0]
        w = row["wasserstein"].values[0]
        sig_str = "*" if ks_p < 0.05 else "ns"
        ax.set_title(
            f"{fate} — W={w:.3f} KS p={ks_p:.3f} {sig_str}",
            fontsize=10,
        )
        ax.set_xlabel("Pseudotime", fontsize=8)
        ax.set_ylabel("Density", fontsize=8)
        if fi == 0:
            ax.legend(fontsize=7, frameon=False)
        sns.despine(ax=ax)

    fig.suptitle(
        title or "Trajectory shift: pseudotime distributions by condition",
        fontsize=12, y=1.01,
    )
    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig
# ------------------------------------------------------------------ # Convenience: transfer labels to full adata # ------------------------------------------------------------------
def transfer_labels(
    self,
    results: Dict[str, CommitmentScoreResult],
    prefix: str = "cs_",
) -> None:
    """Copy per-condition results onto the full ``adata``.

    For every ``(condition, result)`` pair the internal scorer's
    ``transfer_labels()`` is called with ``self.adata`` and the
    condition-specific prefix ``f"{prefix}{condition}_"``.

    Parameters
    ----------
    results : dict
        Output of score_all_conditions(cell_level=True).
    prefix : str
        Leading part of the column prefix. Default: 'cs_'.

    Raises
    ------
    RuntimeError
        If ``fit()`` has not been called.
    """
    self._check_fitted()
    target = self.adata
    for condition in results:
        self._scorer.transfer_labels(
            target,
            results[condition],
            prefix=f"{prefix}{condition}_",
        )
# ------------------------------------------------------------------ # Plotting shortcuts (delegate to scorer) # ------------------------------------------------------------------
def plot_star(self, result: CommitmentScoreResult, **kwargs):
    """Radial star embedding plot, drawn by the internal scorer.

    ``result`` and all keyword arguments are passed straight to
    ``SingleScorer.plot_star()``; its return value is returned unchanged.

    Raises
    ------
    RuntimeError
        If ``fit()`` has not been called.
    """
    self._check_fitted()
    figure = self._scorer.plot_star(result, **kwargs)
    return figure
def plot_star_grid(
    self,
    results: Dict[str, CommitmentScoreResult],
    color_map: Optional[Dict[str, str]] = None,
    figsize_per_panel: Tuple[float, float] = (6, 6),
    save_path: Optional[str] = None,
) -> matplotlib.figure.Figure:
    """Star embedding plots in a single row, one panel per condition.

    Every panel is drawn from the shared ``adata_sub`` with that condition's
    result, colored by fate and titled with the condition label.

    Parameters
    ----------
    results : dict
        condition_label -> CommitmentScoreResult.
    color_map : dict, optional
        Forwarded to ``plot_star_embedding``.
    figsize_per_panel : tuple
        (width, height) of one panel; the figure is ``n`` panels wide.
    save_path : str, optional
        If given, the figure is saved there at 300 dpi.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    RuntimeError
        If ``fit()`` has not been called.
    """
    import matplotlib.pyplot as plt
    from .plot import plot_star_embedding

    self._check_fitted()

    panel_w, panel_h = figsize_per_panel[0], figsize_per_panel[1]
    labels = list(results.keys())
    fig, axes = plt.subplots(
        1, len(labels),
        figsize=(panel_w * len(labels), panel_h),
        squeeze=False,
    )

    for label, panel in zip(labels, axes[0]):
        plot_star_embedding(
            self._scorer.adata_sub,
            results[label],
            color_by="fate",
            color_map=color_map,
            ax=panel,
            title=label,
        )

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
def plot_rose_grid(
    self,
    results: Dict[str, CommitmentScoreResult],
    color_map: Optional[Dict[str, str]] = None,
    figsize_per_panel: Tuple[float, float] = (5, 5),
    title: Optional[str] = None,
    save_path: Optional[str] = None,
) -> matplotlib.figure.Figure:
    """Grid of polar rose plots, one per condition.

    Thin wrapper: all arguments are handed to ``scCS.plot.plot_rose_grid``
    and its figure is returned.

    Parameters
    ----------
    results : dict
        condition_label -> CommitmentScoreResult.
    color_map : dict, optional
    figsize_per_panel : tuple
    title : str, optional
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure
    """
    from .plot import plot_rose_grid as draw_rose_grid

    options = {
        "color_map": color_map,
        "figsize_per_panel": figsize_per_panel,
        "title": title,
        "save_path": save_path,
    }
    return draw_rose_grid(results, **options)
def plot_delta_cs_heatmap(self, delta_result: dict, **kwargs) -> matplotlib.figure.Figure:
    """Heatmap of ΔCS = nCS_A − nCS_B.

    Thin wrapper around :func:`scCS.plot.plot_delta_cs_heatmap`.

    Parameters
    ----------
    delta_result : dict
        Output of compute_delta_CS().
    **kwargs
        Forwarded to :func:`scCS.plot.plot_delta_cs_heatmap`.

    Returns
    -------
    fig : matplotlib Figure
    """
    from .plot import plot_delta_cs_heatmap as draw_heatmap

    figure = draw_heatmap(delta_result, **kwargs)
    return figure
def plot_compare_conditions_bar(
    self,
    results: Dict[str, CommitmentScoreResult],
    **kwargs,
) -> matplotlib.figure.Figure:
    """Grouped bar chart of nCS per condition.

    Thin wrapper around :func:`scCS.plot.plot_compare_conditions_bar`.

    Parameters
    ----------
    results : dict
        Output of score_all_conditions().
    **kwargs
        Forwarded to :func:`scCS.plot.plot_compare_conditions_bar`.

    Returns
    -------
    fig : matplotlib Figure
    """
    from .plot import plot_compare_conditions_bar as draw_bars

    figure = draw_bars(results, **kwargs)
    return figure
def plot_commitment_vector_radar(
    self,
    results: Dict[str, CommitmentScoreResult],
    **kwargs,
) -> matplotlib.figure.Figure:
    """Radar / spider chart of commitment vectors per condition.

    Thin wrapper around :func:`scCS.plot.plot_commitment_vector_radar`.

    Parameters
    ----------
    results : dict
        Output of score_all_conditions().
    **kwargs
        Forwarded to :func:`scCS.plot.plot_commitment_vector_radar`.

    Returns
    -------
    fig : matplotlib Figure
    """
    from .plot import plot_commitment_vector_radar as draw_radar

    figure = draw_radar(results, **kwargs)
    return figure
# ------------------------------------------------------------------ # Representation # ------------------------------------------------------------------ def __repr__(self) -> str: status = "fitted" if self._fitted else "uninitialised" n_cond = len(self.conditions) return ( f"PairScorer(" f"root='{self.root}', " f"branches={self.branches}, " f"conditions={self.conditions}, " f"status='{status}')" ) # ------------------------------------------------------------------ # Properties # ------------------------------------------------------------------ @property
[docs] def scorer(self) -> Optional[SingleScorer]: """The internal SingleScorer used for embedding and scoring.""" return self._scorer
@property
[docs] def adata_sub(self): """The embedding subset (from the internal SingleScorer).""" if self._scorer is not None: return self._scorer.adata_sub return None
@property
[docs] def is_fitted(self) -> bool: return self._fitted
def _check_embedding(self): if self._scorer is None or not self._scorer._embedding_built: raise RuntimeError( "Star embedding not built. Call build_embedding() first." ) def _check_fitted(self): if not self._fitted: raise RuntimeError( "PairScorer is not fitted. Call fit() first." )