# Source code for scCS.plot

"""
plot.py — Publication-quality visualizations for scCS.

Primary visualization: plot_star_embedding()
  Radial star layout with one arm per fate, cells colored by cluster,
  pseudotime, fate affinity, or commitment entropy.  Arm axes are drawn
  with fate labels at the tips.

Additional plots:
  plot_star_panels()         — multi-panel star embedding, one panel per coloring
  plot_rose()                — polar rose of velocity magnitude by direction
  plot_rose_grid()           — per-condition rose plots on a shared radial scale
  plot_pairwise_cs()         — heatmap of pairwise nCS/unCS matrix
  plot_commitment_bar()      — unCS/nCS bar chart per fate pair
  plot_commitment_heatmap()  — per-cell fate affinity heatmap
  plot_expression_trends()   — CellRank-style gene expression vs pseudotime
  plot_subset_comparison()   — multi-subset CS comparison
  plot_nn_entropy_elbow()    — elbow plot for k_nn selection

Multi-condition plots (PairScorer + MultiScorer):
  plot_delta_cs_heatmap()          — ΔCS heatmap with CI annotation
  plot_compare_conditions_bar()    — grouped bar chart of nCS per condition
  plot_commitment_vector_radar()   — radar chart of commitment vectors
  plot_omnibus_summary()           — fates × conditions heatmap with omnibus significance
  plot_posthoc_heatmap()           — condition × condition post-hoc p-value heatmap
  plot_pairwise_delta_grid()       — grid of ΔCS heatmaps for all condition pairs

Color maps
----------
Plot functions that color elements by fate (or condition) accept an optional
``color_map`` dict mapping fate (or condition) name to a hex color string.
Pass this to preserve your original cluster colors from scanpy/Seurat across
scCS plots; names missing from ``color_map`` fall back to the default palette.
Progenitor cells always use PROGENITOR_COLOR (gray) regardless of color_map.

Example::

    # Extract colors from scanpy
    color_map = dict(zip(
        adata.obs['cell_type'].cat.categories,
        adata.uns['cell_type_colors'],
    ))
    scorer.plot_star(result, color_map=color_map)

All plots use seaborn ticks theme.
Figures are returned as matplotlib Figure objects.
"""

from __future__ import annotations

import warnings
from typing import Dict, List, Optional, Tuple, Union

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.patheffects as pe
import numpy as np
import seaborn as sns

from .scores import CommitmentScoreResult

# Default per-fate colors: the colorblind-friendly palette from
# Wong (2011), "Points of view: Color blindness", Nature Methods.
# Fates beyond the eighth wrap around to the start of the list.
FATE_PALETTE = [
    "#0072B2",  # blue
    "#D55E00",  # vermillion
    "#009E73",  # green
    "#CC79A7",  # pink
    "#E69F00",  # orange
    "#56B4E9",  # sky blue
    "#F0E442",  # yellow
    "#000000",  # black
]

# Neutral grey for the progenitor (bifurcation) cluster; a user-supplied
# color_map never overrides it.
PROGENITOR_COLOR = "#AAAAAA"
# Default per-condition colors.  The first eight entries are the same Wong
# (2011) colors as FATE_PALETTE in a different order, so a condition color is
# NOT guaranteed to differ from every fate color.  The last four entries are
# extra colors (not part of the Wong set) used once more than eight
# conditions are plotted.  Beyond twelve conditions the list wraps around.
CONDITION_PALETTE = [
    "#E69F00",  # orange
    "#56B4E9",  # sky blue
    "#009E73",  # green
    "#F0E442",  # yellow
    "#0072B2",  # blue
    "#D55E00",  # vermillion
    "#CC79A7",  # pink
    "#000000",  # black
    "#5B51D3",  # purple
    "#E6AB02",  # dark gold
    "#A6CEE3",  # light blue
    "#B2DF8A",  # light green
]
def _fate_colors( fate_names: List[str], color_map: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: """Return a color dict for fate_names. If color_map is provided, use it for any fate names it contains and fall back to FATE_PALETTE for the rest. This lets users pass their original scanpy/Seurat cluster colors directly. """ out = {} palette_idx = 0 for name in fate_names: if color_map and name in color_map: out[name] = color_map[name] else: out[name] = FATE_PALETTE[palette_idx % len(FATE_PALETTE)] palette_idx += 1 return out def _condition_colors( condition_names: List[str], color_map: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: """Return a color dict for condition_names. Draws from CONDITION_PALETTE (distinct from FATE_PALETTE). If color_map is provided, use it for any condition names it contains and fall back to CONDITION_PALETTE for the rest. """ out = {} palette_idx = 0 for name in condition_names: if color_map and name in color_map: out[name] = color_map[name] else: out[name] = CONDITION_PALETTE[palette_idx % len(CONDITION_PALETTE)] palette_idx += 1 return out def _significance_stars(pval: float) -> str: """Return significance stars for a p-value.""" if pval < 0.001: return "***" elif pval < 0.01: return "**" elif pval < 0.05: return "*" else: return "ns" # --------------------------------------------------------------------------- # 1. Star embedding — primary visualization # ---------------------------------------------------------------------------
def plot_star_embedding(
    adata,
    result: CommitmentScoreResult,
    color_by: str = "fate",
    figsize: Tuple[float, float] = (8, 8),
    point_size: float = 8.0,
    alpha: float = 0.75,
    arm_color: str = "#CCCCCC",
    arm_linewidth: float = 1.5,
    arm_linestyle: str = "--",
    show_arm_labels: bool = True,
    show_velocity: bool = False,
    velocity_scale: float = 1.0,
    color_map: Optional[Dict[str, str]] = None,
    title: Optional[str] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    cmap: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Radial star embedding plot — the primary scCS visualization.

    Draws the X_sccs embedding with:

    - Radial arm axes (dashed lines from origin to each fate tip)
    - Fate labels at the arm tips
    - Cells colored by fate, pseudotime, entropy, or per-fate affinity
    - Optional velocity arrows

    Parameters
    ----------
    adata : AnnData
        Must have X_sccs in obsm.  Arm directions and length are read from
        ``adata.uns["sccs"]["arm_dirs"]`` and ``["arm_scale"]`` (default
        10.0).  If ``arm_dirs`` is absent, no arm axes or labels are drawn.
    result : CommitmentScoreResult
    color_by : str
        What to color cells by:

        - ``"fate"`` — cluster/arm assignment (default)
        - ``"pseudotime"`` — reads ``sccs_pseudotime`` then
          ``velocity_pseudotime``
        - ``"entropy"`` — per-cell commitment entropy (``cs_entropy``)
        - ``"nn_entropy"`` — NN-smoothed entropy (``cs_nn_entropy``;
          requires ``score(k_nn=...)``)
        - a fate name — per-cell affinity (``cs_{fate}``; requires
          ``score(cell_level=True)``)
        - any other str — auto-detected numeric or categorical column in
          ``adata.obs``
    figsize : tuple
        Used only when ``ax`` is None.
    point_size : float
    alpha : float
    arm_color : str
        Color of the radial arm guide lines.
    arm_linewidth : float
    arm_linestyle : str
    show_arm_labels : bool
        Draw fate name labels at arm tips.
    show_velocity : bool
        Overlay velocity arrows (requires velocity_sccs in obsm).  Arrows are
        drawn for a random subsample of at most 300 cells.  If velocity_sccs
        is missing, a warning is emitted and no arrows are drawn.
    velocity_scale : float
        Scale factor for velocity arrows.
    color_map : dict, optional
        Fate name -> hex color.  Fates not in the mapping fall back to
        FATE_PALETTE.
    title : str, optional
    vmin, vmax : float, optional
        Color-scale limits for numeric ``color_by`` modes.  Default to the
        finite data range, so structure is always visible regardless of the
        absolute entropy/affinity scale.  Pass explicit values to pin limits
        for cross-figure comparison.
    cmap : str, optional
        Matplotlib colormap name.  Defaults: ``"RdYlBu_r"`` for entropy,
        ``"viridis"`` for pseudotime/generic numeric, ``"Blues"`` for
        per-fate affinity.
    ax : matplotlib Axes, optional
        Draw into this axes instead of creating a new figure.
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure
        The figure that owns the axes drawn into.

    Raises
    ------
    KeyError
        If ``adata.obsm`` has no ``X_sccs`` embedding.
    """
    sns.set_theme(style="ticks")

    if "X_sccs" not in adata.obsm:
        raise KeyError("X_sccs not found in adata.obsm. Run build_embedding() first.")

    coords = np.array(adata.obsm["X_sccs"])
    sccs_meta = adata.uns.get("sccs", {})
    arm_scale = sccs_meta.get("arm_scale", 10.0)
    arm_dirs = sccs_meta.get("arm_dirs", None)
    fate_names = result.fate_names
    colors = _fate_colors(fate_names, color_map)

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # --- Draw radial arm axes ---
    if arm_dirs is not None:
        # Coerce to a float array.  If arm_dirs is stored as plain lists,
        # `arm_dir * arm_scale` would be list repetition or a TypeError
        # instead of element-wise scaling.
        arm_dirs = np.asarray(arm_dirs, dtype=float)
        for name, arm_dir in zip(fate_names, arm_dirs):
            tip = arm_dir * arm_scale * 1.15
            ax.plot(
                [0, tip[0]], [0, tip[1]],
                color=arm_color,
                linewidth=arm_linewidth,
                linestyle=arm_linestyle,
                zorder=1,
            )
            if show_arm_labels:
                # Labels sit slightly beyond the arm tip.
                label_pos = arm_dir * arm_scale * 1.25
                ax.text(
                    label_pos[0], label_pos[1], name,
                    ha="center", va="center",
                    fontsize=11, fontweight="bold",
                    color=colors[name],
                    path_effects=[pe.withStroke(linewidth=2, foreground="white")],
                )

    # Mark origin
    ax.scatter(0, 0, s=80, c="black", marker="+", zorder=5, linewidths=1.5)

    # --- Color cells ---
    _scatter_cells(
        ax, adata, coords, result, color_by, fate_names, colors,
        point_size, alpha,
        vmin=vmin, vmax=vmax, cmap=cmap,
    )

    # --- Velocity arrows ---
    if show_velocity:
        if "velocity_sccs" in adata.obsm:
            V = np.array(adata.obsm["velocity_sccs"])
            # Subsample for readability
            n_arrows = min(300, adata.n_obs)
            idx = np.random.choice(adata.n_obs, n_arrows, replace=False)
            ax.quiver(
                coords[idx, 0], coords[idx, 1],
                V[idx, 0] * velocity_scale,
                V[idx, 1] * velocity_scale,
                alpha=0.5, color="black",
                scale=20, width=0.003,
                headwidth=4, headlength=5,
                zorder=4,
            )
        else:
            warnings.warn(
                "show_velocity=True but 'velocity_sccs' is not in adata.obsm; "
                "velocity arrows were not drawn.",
                stacklevel=2,
            )

    # --- Formatting ---
    ax.set_aspect("equal")
    ax.set_xlabel("scCS dim 1", fontsize=10)
    ax.set_ylabel("scCS dim 2", fontsize=10)
    ax.set_title(
        title or f"scCS Star Embedding (bifurcation: cluster '{sccs_meta.get('root', '?')}')",
        fontsize=11,
    )
    sns.despine(ax=ax)
    # Lay out the figure we drew into; plt.tight_layout() would act on the
    # *current* figure, which need not own a caller-supplied `ax`.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
def plot_nn_entropy_elbow(
    scorer,
    k_nn_range: Union[List[int], range] = range(5, 51, 5),
    color_map: Optional[Dict[str, str]] = None,
    figsize: Tuple[float, float] = (12, 5),
    title: Optional[str] = None,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Elbow plots for choosing the optimal number of nearest neighbors (k_nn).

    Sweeps over ``k_nn_range``, computing NN-smoothed cell entropy at each k,
    and produces two side-by-side subplots:

    - **Left**: mean NN entropy across all cells vs k_nn.
    - **Right**: mean NN entropy per fate arm vs k_nn (one line per fate).
      A fate with no member cells is plotted as NaN.

    Use these plots to identify the elbow — the k_nn where entropy
    stabilizes, indicating that additional smoothing no longer changes the
    signal.

    Parameters
    ----------
    scorer : SingleScorer
        A fitted scorer with ``build_embedding()`` and ``fit()`` already
        called.  No prior ``score()`` call is needed — cell scores are
        recomputed internally from the velocity vectors.
    k_nn_range : list or range
        k_nn values to sweep.  Default: 5, 10, 15, ..., 50.
    color_map : dict, optional
        Fate name -> hex color.  Falls back to the default FATE_PALETTE.
    figsize : tuple
    title : str, optional
        Overall figure title.  Defaults to "NN Entropy Elbow".
    save_path : str, optional
        If provided, save figure to this path.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    RuntimeError
        If the scorer is not fitted or its velocity vectors are not loaded.
    ValueError
        If ``k_nn_range`` is empty.

    Examples
    --------
    >>> scorer.build_embedding(differentiation_metric='pseudotime')
    >>> scorer.fit()
    >>> result = scorer.score(compute_cell_level=True)
    >>> fig = scorer.plot_nn_entropy_elbow()
    """
    from .scores import compute_nn_cell_entropy, compute_cell_scores

    if scorer._fate_map is None or not scorer._fitted:
        raise RuntimeError("scorer must be fitted before plotting elbow.")
    if scorer._vx is None:
        raise RuntimeError("Velocity vectors not loaded. Call fit() or load_velocity_vectors().")

    k_nn_list = list(k_nn_range)
    if not k_nn_list:
        raise ValueError("k_nn_range must contain at least one k value.")

    fate_map = scorer._fate_map
    fate_names = fate_map.fate_names
    coords = np.array(scorer.adata_sub.obsm["X_sccs"])

    # Per-cell fate scores do not depend on k, so compute them once.
    cell_scores = compute_cell_scores(
        scorer._vx, scorer._vy,
        fate_map.fate_centroids, fate_map.root_centroid,
    )

    # Fate arm membership for per-fate means.  Labels are cast to str, so the
    # fate names must be too; otherwise a non-str name (e.g. an integer
    # cluster id) would match no cells.
    cluster_labels = scorer.adata_sub.obs[scorer.obs_key].astype(str).values
    fate_masks = {name: cluster_labels == str(name) for name in fate_names}

    mean_all = []
    mean_per_fate = {name: [] for name in fate_names}
    for k in k_nn_list:
        nn_ent = compute_nn_cell_entropy(cell_scores, coords, k)
        mean_all.append(nn_ent.mean())
        for name in fate_names:
            mask = fate_masks[name]
            mean_per_fate[name].append(nn_ent[mask].mean() if mask.any() else float("nan"))

    colors = _fate_colors(fate_names, color_map)

    sns.set_theme(style="ticks")
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    fig.suptitle(title or "NN Entropy Elbow", fontsize=13, y=1.01)

    # --- Left: overall mean ---
    ax = axes[0]
    ax.plot(k_nn_list, mean_all, color="#333333", linewidth=2,
            marker="o", markersize=5, label="All cells")
    ax.set_xlabel("k (nearest neighbors)", fontsize=11)
    ax.set_ylabel("Mean NN-smoothed entropy", fontsize=11)
    ax.set_title("Overall", fontsize=11)
    ax.set_xticks(k_nn_list)
    ax.tick_params(axis="x", rotation=45)
    sns.despine(ax=ax)

    # --- Right: per-fate means ---
    ax = axes[1]
    for name in fate_names:
        ax.plot(k_nn_list, mean_per_fate[name], color=colors[name],
                linewidth=2, marker="o", markersize=5, label=name)
    ax.set_xlabel("k (nearest neighbors)", fontsize=11)
    ax.set_ylabel("Mean NN-smoothed entropy", fontsize=11)
    ax.set_title("Per fate", fontsize=11)
    ax.set_xticks(k_nn_list)
    ax.tick_params(axis="x", rotation=45)
    ax.legend(frameon=False, fontsize=9)
    sns.despine(ax=ax)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# ---------------------------------------------------------------------------
# Internal: per-cell scatter shared by the star-embedding plots
# ---------------------------------------------------------------------------
def _scatter_cells(
    ax, adata, coords, result, color_by, fate_names, colors, point_size, alpha,
    vmin=None, vmax=None, cmap=None,
):
    """Internal: scatter cells with the requested coloring scheme.

    Supports these ``color_by`` modes:

    - ``"fate"`` (default) — categorical: progenitor (PROGENITOR_COLOR),
      each fate arm (``colors[name]``), and ``"unassigned"`` cells (light
      gray), read from ``adata.obs["sccs_branch"]``; all cells are drawn
      gray if that column is absent
    - ``"entropy"`` or ``"cs_entropy"`` — per-cell commitment entropy
      (``adata.obs["cs_entropy"]``)
    - ``"nn_entropy"`` or ``"cs_nn_entropy"`` — NN-smoothed entropy
      (``adata.obs["cs_nn_entropy"]``; requires ``score(k_nn=...)``)
    - ``"pseudotime"`` — pseudotime, reading ``sccs_pseudotime`` then
      ``velocity_pseudotime``
    - a fate name in ``fate_names`` — per-cell affinity column ``cs_{fate}``
      (requires ``score(cell_level=True)``)
    - any other ``adata.obs`` column — numeric if it converts to float,
      otherwise categorical (colored from FATE_PALETTE)

    When the required column is missing, a warning is emitted and all cells
    are drawn in gray.

    Numeric color scales auto-scale to the finite data range by default.
    Pass ``vmin`` / ``vmax`` to pin limits (useful for cross-figure
    comparison).  Pass ``cmap`` to override the default colormap for that
    branch.  Numeric modes attach a colorbar to ``ax``.
    """
    sccs_meta = adata.uns.get("sccs", {})
    bif_cluster = sccs_meta.get("root", None)

    def _auto_lim(values, vmin_in, vmax_in):
        """Compute (vmin, vmax) defaulting to data min/max when not provided."""
        finite = np.asarray(values, dtype=float)
        finite = finite[np.isfinite(finite)]
        if finite.size == 0:
            return (vmin_in if vmin_in is not None else 0.0,
                    vmax_in if vmax_in is not None else 1.0)
        lo = vmin_in if vmin_in is not None else float(np.nanmin(finite))
        hi = vmax_in if vmax_in is not None else float(np.nanmax(finite))
        if hi <= lo:
            hi = lo + 1e-6
        return lo, hi

    def _scatter_gray():
        """Fallback: draw every cell in uniform gray (no legend/colorbar)."""
        ax.scatter(coords[:, 0], coords[:, 1], c="gray",
                   s=point_size, alpha=alpha, rasterized=True)

    def _scatter_numeric(vals, default_cmap, label):
        """Color cells by a numeric vector and attach a colorbar to ``ax``."""
        lo, hi = _auto_lim(vals, vmin, vmax)
        sc = ax.scatter(
            coords[:, 0], coords[:, 1],
            c=vals, cmap=(cmap or default_cmap),
            s=point_size, alpha=alpha,
            vmin=lo, vmax=hi,
            zorder=2, rasterized=True,
        )
        # Go through ax's own figure: plt.colorbar() routes through the
        # *current* figure, which need not be the one that owns `ax`.
        ax.figure.colorbar(sc, ax=ax, label=label, shrink=0.7, pad=0.02)

    if color_by == "fate":
        # Color by arm assignment (categorical)
        branch = adata.obs.get("sccs_branch", None)
        if branch is None:
            _scatter_gray()
            return
        arm_names = branch.astype(str).values

        # Bifurcation cluster
        if bif_cluster is not None:
            bif_mask = arm_names == str(bif_cluster)
            if bif_mask.any():
                ax.scatter(
                    coords[bif_mask, 0], coords[bif_mask, 1],
                    c=PROGENITOR_COLOR, s=point_size, alpha=alpha,
                    label=f"Progenitor ({bif_cluster})",
                    zorder=2, rasterized=True,
                )

        # Each fate arm
        for name in fate_names:
            mask = arm_names == str(name)
            if mask.any():
                ax.scatter(
                    coords[mask, 0], coords[mask, 1],
                    c=colors[name], s=point_size, alpha=alpha,
                    label=name, zorder=2, rasterized=True,
                )

        # Unassigned cells are drawn smaller, fainter, and underneath.
        unassigned = arm_names == "unassigned"
        if unassigned.any():
            ax.scatter(
                coords[unassigned, 0], coords[unassigned, 1],
                c="#DDDDDD", s=point_size * 0.6, alpha=alpha * 0.5,
                label="other", zorder=1, rasterized=True,
            )

        ax.legend(
            markerscale=2.5, fontsize=8, frameon=False,
            loc="upper right", bbox_to_anchor=(1.0, 1.0),
        )

    elif color_by in ("entropy", "cs_entropy"):
        col = "cs_entropy"
        if col not in adata.obs:
            warnings.warn("cs_entropy not in adata.obs. Run score() first.", stacklevel=3)
            _scatter_gray()
            return
        _scatter_numeric(adata.obs[col].values.astype(float), "RdYlBu_r", "Commitment entropy")

    elif color_by in ("nn_entropy", "cs_nn_entropy"):
        col = "cs_nn_entropy"
        if col not in adata.obs:
            warnings.warn(
                "cs_nn_entropy not in adata.obs. "
                "Call score(k_nn=...) with a positive k_nn to populate it.",
                stacklevel=3,
            )
            _scatter_gray()
            return
        _scatter_numeric(adata.obs[col].values.astype(float), "RdYlBu_r", "NN-smoothed entropy")

    elif color_by == "pseudotime":
        # Try sccs_pseudotime first, then velocity_pseudotime
        col = next(
            (c for c in ("sccs_pseudotime", "velocity_pseudotime") if c in adata.obs),
            None,
        )
        if col is None:
            warnings.warn(
                "Neither 'sccs_pseudotime' nor 'velocity_pseudotime' found in adata.obs. "
                "Run refit_pseudotime() or compute velocity pseudotime first.",
                stacklevel=3,
            )
            _scatter_gray()
            return
        _scatter_numeric(adata.obs[col].values.astype(float), "viridis", col)

    elif color_by in fate_names:
        # Per-fate affinity
        col = f"cs_{color_by}"
        if col not in adata.obs:
            warnings.warn(
                f"'{col}' not in adata.obs. Run score(cell_level=True) first.",
                stacklevel=3,
            )
            _scatter_gray()
            return
        _scatter_numeric(adata.obs[col].values.astype(float), "Blues", f"Affinity: {color_by}")

    else:
        # Generic: numeric or categorical column in adata.obs
        if color_by not in adata.obs:
            warnings.warn(
                f"'{color_by}' not found in adata.obs. Coloring by gray.",
                stacklevel=3,
            )
            _scatter_gray()
            return
        vals = adata.obs[color_by]
        # Only the float conversion decides numeric vs categorical; errors
        # raised while drawing must not be mistaken for "not numeric".
        try:
            vals_float = vals.astype(float).values
        except (ValueError, TypeError):
            vals_float = None

        if vals_float is not None:
            _scatter_numeric(vals_float, "viridis", color_by)
        else:
            # Categorical
            as_str = vals.astype(str).values
            categories = vals.astype("category").cat.categories
            for i, cat in enumerate(categories):
                mask = as_str == str(cat)
                ax.scatter(
                    coords[mask, 0], coords[mask, 1],
                    c=FATE_PALETTE[i % len(FATE_PALETTE)],
                    s=point_size, alpha=alpha,
                    label=str(cat), zorder=2, rasterized=True,
                )
            ax.legend(markerscale=2.5, fontsize=8, frameon=False)


# ---------------------------------------------------------------------------
# 2. Multi-panel star embedding
# ---------------------------------------------------------------------------
def plot_star_panels(
    adata,
    result: CommitmentScoreResult,
    panels: Optional[List[str]] = None,
    figsize_per_panel: Tuple[float, float] = (6, 6),
    point_size: float = 6.0,
    alpha: float = 0.75,
    color_map: Optional[Dict[str, str]] = None,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Multi-panel star embedding: one panel per coloring scheme.

    Default panels: fate assignment, pseudotime, entropy, + one per fate.
    Panels are laid out in rows of at most three; unused grid cells are
    hidden.

    Parameters
    ----------
    adata : AnnData
    result : CommitmentScoreResult
    panels : list of str, optional
        List of color_by values (see plot_star_embedding).  Defaults to
        ['fate', 'pseudotime', 'entropy'] + fate_names.
    figsize_per_panel : tuple
        Size of each panel; the figure size scales with the grid shape.
    point_size : float
    alpha : float
    color_map : dict, optional
        Fate name -> hex color, forwarded to every panel.
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``panels`` is empty.
    """
    if panels is None:
        panels = ["fate", "pseudotime", "entropy"] + list(result.fate_names)

    n = len(panels)
    if n == 0:
        # Without this guard the grid computation divides by zero.
        raise ValueError("panels must contain at least one color_by value.")
    ncols = min(n, 3)
    nrows = int(np.ceil(n / ncols))

    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(figsize_per_panel[0] * ncols, figsize_per_panel[1] * nrows),
        squeeze=False,
    )
    flat_axes = axes.ravel()

    for panel, ax in zip(panels, flat_axes):
        plot_star_embedding(
            adata, result,
            color_by=panel,
            point_size=point_size,
            alpha=alpha,
            color_map=color_map,
            ax=ax,
            title=panel,
        )

    # Hide unused axes
    for ax in flat_axes[n:]:
        ax.set_visible(False)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 3. Rose / polar plot # ---------------------------------------------------------------------------
def plot_rose(
    result: CommitmentScoreResult,
    title: str = "Cumulative Velocity Magnitude by Direction",
    figsize: Tuple[float, float] = (7, 7),
    show_sectors: bool = True,
    color_map: Optional[Dict[str, str]] = None,
    ax: Optional[plt.Axes] = None,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Polar rose plot of cumulative velocity magnitudes per angular bin.

    Each bin shows the total velocity magnitude pointing in that direction.
    Fate sectors are shaded with distinct colors.  The radial axis runs from
    0 to 1.15 × the tallest bin (1.0 if no bin is positive).  This matches
    plot_rose_grid and leaves room for the fate labels drawn at the rim.

    Parameters
    ----------
    result : CommitmentScoreResult
    title : str
    figsize : tuple
        Used only when ``ax`` is None.
    show_sectors : bool
        Color bins by the fate sector they belong to; otherwise all bins are
        gray.
    color_map : dict, optional
        Fate name -> hex color.  Falls back to FATE_PALETTE.
    ax : matplotlib Axes (polar), optional
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure
    """
    sns.set_theme(style="ticks")

    n_bins = len(result.M_bin)
    bin_width = 2 * np.pi / n_bins
    bin_centers = np.linspace(0, 2 * np.pi, n_bins, endpoint=False) + bin_width / 2
    colors = _fate_colors(result.fate_names, color_map)

    # Radial limit, computed once; guarded against all-zero / NaN magnitudes.
    peak = float(np.nanmax(result.M_bin))
    r_max = peak * 1.15 if np.isfinite(peak) and peak > 0 else 1.0

    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection="polar")
    else:
        fig = ax.get_figure()

    bin_colors = ["#cccccc"] * n_bins
    if show_sectors:
        for name, sector_bins in zip(result.fate_names, result.sectors):
            for b in sector_bins:
                bin_colors[b] = colors[name]

    ax.bar(
        bin_centers, result.M_bin,
        width=bin_width * 0.9,
        color=bin_colors,
        alpha=0.85,
        edgecolor="white",
        linewidth=0.5,
    )
    # Without an explicit limit the autoscaled radius hugs the tallest bar,
    # and the labels placed at r_max below end up outside the axes.
    ax.set_ylim(0, r_max)

    if result.fate_angles is not None:
        for name, angle in zip(result.fate_names, result.fate_angles):
            ax.annotate(
                name,
                xy=(np.radians(angle), r_max),
                ha="center", va="center",
                fontsize=10, fontweight="bold",
                color=colors[name],
                # Labels sit on the rim.  By default an annotation whose data
                # point is not inside the axes is not drawn at all.
                annotation_clip=False,
            )

    patches = [
        mpatches.Patch(color=colors[name], label=f"{name} (M={result.M_sector[j]:.1f})")
        for j, name in enumerate(result.fate_names)
    ]
    ax.legend(handles=patches, loc="upper right",
              bbox_to_anchor=(1.35, 1.1), fontsize=9)
    ax.set_title(title, pad=20, fontsize=12)
    ax.set_theta_zero_location("E")
    ax.set_theta_direction(1)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 3b. Per-condition rose grid # ---------------------------------------------------------------------------
def plot_rose_grid(
    results: Dict[str, "CommitmentScoreResult"],
    color_map: Optional[Dict[str, str]] = None,
    figsize_per_panel: Tuple[float, float] = (5, 5),
    title: Optional[str] = None,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Grid of polar rose plots — one per condition.

    All panels share the same radial scale: 1.15 × the largest M_bin value
    across conditions, or 1.0 if that maximum is not positive.  This makes
    magnitudes directly comparable.  Fate sectors are shaded with
    FATE_PALETTE colors, consistent with single-condition plot_rose.

    Fate colors are assigned over the union of all conditions' fate names,
    in first-seen order, so conditions with differing fate sets still plot.

    Parameters
    ----------
    results : dict
        Mapping of condition_label -> CommitmentScoreResult (output of
        PairScorer.score_all_conditions()).
    color_map : dict, optional
        fate_name -> hex color.  Falls back to FATE_PALETTE.
    figsize_per_panel : tuple
        Size of each polar subplot.
    title : str, optional
        Overall figure title.
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``results`` is empty.
    """
    sns.set_theme(style="ticks")

    conditions = list(results.keys())
    n = len(conditions)
    if n == 0:
        raise ValueError("results dict is empty.")

    # Shared radial scale
    r_max = max(res.M_bin.max() for res in results.values())
    r_max = r_max * 1.15 if r_max > 0 else 1.0

    # Union of fate names in first-seen order.  When every condition has the
    # same fates, this equals the first result's fate_names.
    fate_names = list(dict.fromkeys(
        name for res in results.values() for name in res.fate_names
    ))
    colors = _fate_colors(fate_names, color_map)

    fig = plt.figure(figsize=(figsize_per_panel[0] * n, figsize_per_panel[1]))

    for idx, cond in enumerate(conditions):
        res = results[cond]
        ax = fig.add_subplot(1, n, idx + 1, projection="polar")

        n_bins = len(res.M_bin)
        bin_width = 2 * np.pi / n_bins
        bin_centers = np.linspace(0, 2 * np.pi, n_bins, endpoint=False) + bin_width / 2

        bin_colors = ["#cccccc"] * n_bins
        for name, sector_bins in zip(res.fate_names, res.sectors):
            for b in sector_bins:
                bin_colors[b] = colors[name]

        ax.bar(
            bin_centers, res.M_bin,
            width=bin_width * 0.9,
            color=bin_colors,
            alpha=0.85,
            edgecolor="white",
            linewidth=0.5,
        )

        # Fate arm labels, drawn on the shared rim.
        if res.fate_angles is not None:
            for name, angle in zip(res.fate_names, res.fate_angles):
                ax.annotate(
                    name,
                    xy=(np.radians(angle), r_max),
                    ha="center", va="center",
                    fontsize=9, fontweight="bold",
                    color=colors[name],
                    # xy lies exactly on the axes boundary; never clip it.
                    annotation_clip=False,
                )

        ax.set_ylim(0, r_max)
        ax.set_title(cond, pad=15, fontsize=11, fontweight="bold")
        ax.set_theta_zero_location("E")
        ax.set_theta_direction(1)
        ax.grid(True, alpha=0.3)

    # Shared legend covering every fate that appears in any condition.
    patches = [mpatches.Patch(color=colors[name], label=name) for name in fate_names]
    fig.legend(
        handles=patches,
        loc="lower center",
        ncol=min(len(fate_names), 4),
        fontsize=9,
        frameon=False,
        bbox_to_anchor=(0.5, -0.05),
    )

    if title:
        fig.suptitle(title, fontsize=13, y=1.02)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 4. Pairwise CS heatmap # ---------------------------------------------------------------------------
def plot_pairwise_cs(
    result: CommitmentScoreResult,
    normalized: bool = True,
    title: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    cmap: str = "RdBu_r",
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Heatmap of pairwise commitment scores.

    Entry [i, j] = CS(fate_i relative to fate_j).  Values > 1 indicate
    stronger commitment to fate_i than fate_j.

    Colors are log2-transformed and centered at 0 (CS = 1), using a
    symmetric range set by the largest finite |log2(CS)|.  Values are clipped
    at 1e-6 before the log.  Cell annotations show the raw CS rounded to two
    decimals.

    Parameters
    ----------
    result : CommitmentScoreResult
    normalized : bool
        Use nCS (True) or unCS (False).
    title : str, optional
    figsize : tuple, optional
        Defaults to a size scaled by the number of fates.
    cmap : str
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure
    """
    import pandas as pd

    mat = result.pairwise_nCS if normalized else result.pairwise_unCS
    df = pd.DataFrame(mat, index=result.fate_names, columns=result.fate_names)

    k = result.k
    if figsize is None:
        figsize = (max(4, k * 1.2), max(3.5, k * 1.0))

    sns.set_theme(style="ticks")
    fig, ax = plt.subplots(figsize=figsize)

    log_mat = np.log2(np.clip(mat, 1e-6, None))
    log_df = pd.DataFrame(log_mat, index=result.fate_names, columns=result.fate_names)

    finite_vals = log_mat[np.isfinite(log_mat)]
    vmax = float(np.abs(finite_vals).max()) if finite_vals.size > 0 else 5.0
    if vmax == 0.0:
        # Every finite entry is exactly CS = 1.  Keep a non-zero symmetric
        # range instead of vmin == vmax == 0.
        vmax = 1.0

    sns.heatmap(
        log_df, ax=ax,
        cmap=cmap, center=0,
        vmin=-vmax, vmax=vmax,
        annot=df.round(2), fmt=".2f",
        linewidths=0.5,
        cbar_kws={"label": "log2(CS)", "shrink": 0.8},
    )

    label = "Normalized CS (nCS)" if normalized else "Unnormalized CS (unCS)"
    ax.set_title(title or f"Pairwise Commitment Scores ({label})", fontsize=11)
    ax.set_xlabel("Reference fate (denominator)")
    ax.set_ylabel("Query fate (numerator)")
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 5. Commitment bar chart # ---------------------------------------------------------------------------
def plot_commitment_bar(
    result: CommitmentScoreResult,
    ref_fate: Optional[str] = None,
    mode: str = "auto",
    color_map: Optional[Dict[str, str]] = None,
    title: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Bar chart of unCS and nCS for all k populations.

    For a k-furcation, produces k subplots — one per reference fate.  Each
    subplot shows unCS (solid) and nCS (hatched) for all other k-1 fates
    relative to that reference.  This way every population is shown as both
    a query and a reference, and nothing is hidden.

    For k=2, two subplots are produced (one per reference fate); pass
    ``ref_fate`` to get a single subplot.

    A dashed CS = 1 reference line is drawn.  The y-axis always extends past
    1 so that line stays visible.  Non-finite values get no text label.

    Parameters
    ----------
    result : CommitmentScoreResult
    ref_fate : str, optional
        If given, produce only a single subplot using this fate as
        reference.  Useful when you want a focused comparison.
    mode : str
        Kept for backward compatibility; ignored.
    color_map : dict, optional
        Mapping of fate name → hex color.
    title : str, optional
        Overall figure title.
    figsize : tuple, optional
        Per-subplot size ``(w, h)`` (default ``(4.0, 4.5)``).  Total figure
        width scales with the number of subplots.
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``ref_fate`` is not one of ``result.fate_names``.
    """
    colors = _fate_colors(result.fate_names, color_map)
    sns.set_theme(style="ticks")

    # Decide which reference fates to show
    if ref_fate is not None:
        if ref_fate not in result.fate_names:
            raise ValueError(
                f"ref_fate '{ref_fate}' not in fate_names: "
                f"{result.fate_names}"
            )
        ref_indices = [result.fate_names.index(ref_fate)]
    else:
        ref_indices = list(range(result.k))

    n_panels = len(ref_indices)
    panel_w, panel_h = figsize if figsize is not None else (4.0, 4.5)
    fig, axes = plt.subplots(
        1, n_panels,
        figsize=(panel_w * n_panels, panel_h),
        squeeze=False,
    )

    for col, ref_idx in enumerate(ref_indices):
        ax = axes[0, col]
        ref_name = result.fate_names[ref_idx]
        query_names = [n for i, n in enumerate(result.fate_names) if i != ref_idx]
        query_idx = [i for i in range(result.k) if i != ref_idx]

        unCS_vals = [result.pairwise_unCS[i, ref_idx] for i in query_idx]
        nCS_vals = [result.pairwise_nCS[i, ref_idx] for i in query_idx]

        x = np.arange(len(query_names))
        width = 0.35
        bar_colors = [colors[n] for n in query_names]

        bars_un = ax.bar(
            x - width / 2, unCS_vals, width,
            color=bar_colors, alpha=0.90, label="unCS",
            edgecolor="white", linewidth=0.5,
        )
        bars_n = ax.bar(
            x + width / 2, nCS_vals, width,
            color=bar_colors, alpha=0.55, hatch="///",
            edgecolor="white", linewidth=0.5, label="nCS",
        )
        ax.axhline(1.0, color="black", linestyle="--",
                   linewidth=1.0, alpha=0.45, label="CS = 1")

        # Y-range: include 1.0 so the CS = 1 reference line is always in view
        # (and the limits never collapse to (0, 0) when all scores are 0).
        all_vals = [v for v in unCS_vals + nCS_vals if np.isfinite(v)]
        y_max = max(all_vals + [1.0]) if all_vals else 2.0
        pad = y_max * 0.03

        # Value labels above each finite bar; nCS labels in gray.
        for bars, text_kw in ((bars_un, {}), (bars_n, {"color": "#555555"})):
            for bar in bars:
                h = bar.get_height()
                if np.isfinite(h):
                    ax.text(
                        bar.get_x() + bar.get_width() / 2, h + pad,
                        f"{h:.2f}",
                        ha="center", va="bottom", fontsize=7.5,
                        **text_kw,
                    )

        ax.set_xticks(x)
        ax.set_xticklabels(query_names, rotation=15, ha="right")
        ax.set_ylim(0, y_max * 1.20)
        ax.set_ylabel(f"CS (÷ '{ref_name}')" if col == 0 else "")
        ax.set_title(f"vs '{ref_name}'", fontsize=10)
        if col == 0:
            ax.legend(frameon=False, fontsize=8)
        sns.despine(ax=ax)

    fig.suptitle(
        title or "Commitment scores — all populations",
        fontsize=11, y=1.02,
    )
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 6. Per-cell commitment heatmap # ---------------------------------------------------------------------------
def plot_commitment_heatmap(
    result: CommitmentScoreResult,
    cell_scores: Optional[np.ndarray] = None,
    max_cells: int = 500,
    title: str = "Per-Cell Fate Affinity",
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Heatmap of per-cell fate affinity scores (cells × fates).

    Rows are grouped by dominant fate (the argmax column).  When there are
    more than ``max_cells`` cells, a random subset of ``max_cells`` rows is
    drawn without replacement (global NumPy RNG) before sorting.  The color
    scale is fixed to [0, 1].

    Parameters
    ----------
    result : CommitmentScoreResult
    cell_scores : array-like, shape (n_cells, k), optional
        If None, uses result.cell_scores.
    max_cells : int
        Subsample to this many cells for readability.
    title : str
    figsize : tuple, optional
        Defaults to a width scaled by k and a height scaled by the number of
        plotted rows (clamped to [3, 8]).
    save_path : str, optional

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If no cell scores are available.
    """
    import pandas as pd

    scores = cell_scores if cell_scores is not None else result.cell_scores
    if scores is None:
        raise ValueError(
            "cell_scores not available. Run scorer.score(compute_cell_level=True)."
        )
    # Accept any array-like; positional row indexing below needs an ndarray.
    scores = np.asarray(scores)

    n_cells = scores.shape[0]
    if n_cells > max_cells:
        idx = np.random.choice(n_cells, max_cells, replace=False)
    else:
        idx = np.arange(n_cells)
    # Order the selected rows by dominant fate.
    idx = idx[np.argsort(np.argmax(scores[idx], axis=1))]
    scores_plot = scores[idx]
    n_rows = scores_plot.shape[0]

    if figsize is None:
        # Height follows the rows actually drawn, not the full cell count.
        figsize = (max(4, result.k * 1.5), min(8, max(3, n_rows / 80)))

    sns.set_theme(style="ticks")
    fig, ax = plt.subplots(figsize=figsize)

    df = pd.DataFrame(scores_plot, columns=result.fate_names)
    sns.heatmap(
        df, ax=ax,
        cmap="Blues", vmin=0, vmax=1,
        xticklabels=True, yticklabels=False,
        cbar_kws={"label": "Fate affinity", "shrink": 0.8},
    )
    ax.set_title(title, fontsize=11)
    ax.set_xlabel("Fate")
    ax.set_ylabel(f"Cells (n={n_rows}, sorted by dominant fate)")
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 7. Multi-subset comparison # ---------------------------------------------------------------------------
def _auto_subset_ref_fate(subset_results: dict) -> Optional[str]:
    """Return one reference fate shared by every subset, or None if none is.

    Among the fates present in all subsets, picks the one with the smallest
    mean ``M_sector`` across subsets (ties resolved by the first subset's
    fate order).  With a single subset this reduces to that subset's
    smallest-``M_sector`` fate.
    """
    results = list(subset_results.values())
    shared = set(results[0].fate_names)
    for res in results[1:]:
        shared &= set(res.fate_names)
    candidates = [f for f in results[0].fate_names if f in shared]
    if not candidates:
        return None
    mean_m = [
        np.mean([res.M_sector[list(res.fate_names).index(f)] for res in results])
        for f in candidates
    ]
    return candidates[int(np.argmin(mean_m))]


def plot_subset_comparison(
    subset_results: dict,
    ref_fate: Optional[str] = None,
    normalized: bool = True,
    title: str = "Commitment Score by Subset",
    figsize: Tuple[float, float] = (8, 4),
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Compare commitment scores across multiple subsets.

    For every subset, each non-reference fate contributes one bar: its
    pairwise CS relative to the reference fate.  Bars are grouped by subset
    and coloured by fate; a dashed line marks CS = 1.

    Subsets whose chosen reference pair yields a non-finite CS (e.g.
    progenitor-only subsets with no fate-arm cells, so ``pairwise_nCS`` is
    undefined) are rendered as small gray hatched stubs (4 % of the
    y-range) annotated ``"inf"``, instead of silently producing empty bars.

    Parameters
    ----------
    subset_results : dict
        Mapping of subset_name -> CommitmentScoreResult (from
        ``SingleScorer.score_per_subset``).
    ref_fate : str, optional
        Reference fate for the CS column.  If None, one reference is used
        for all subsets: the fate present in every subset with the smallest
        mean sector magnitude (``M_sector``).  If no fate is shared by all
        subsets, each subset falls back to its own smallest-``M_sector``
        fate and a warning is issued, since bars are then not directly
        comparable.
    normalized : bool
        If True use ``pairwise_nCS``, else ``pairwise_unCS``.
    title : str
    figsize : tuple
    save_path : str, optional
        If given, the figure is also written to this path at 300 dpi.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``subset_results`` is empty, ``ref_fate`` is missing from a
        subset, or no subset has a non-reference fate to plot.
    """
    import pandas as pd

    if not subset_results:
        raise ValueError("subset_results is empty.")

    shared_ref = ref_fate if ref_fate is not None else _auto_subset_ref_fate(subset_results)
    if shared_ref is None:
        warnings.warn(
            "No fate is shared by all subsets; the reference fate is chosen "
            "per subset (smallest M_sector), so CS values are not directly "
            "comparable across subsets.",
            stacklevel=2,
        )

    rows = []
    for subset_name, result in subset_results.items():
        if shared_ref is None:
            ref_idx = int(np.argmin(result.M_sector))
        else:
            if shared_ref not in result.fate_names:
                raise ValueError(
                    f"ref_fate '{shared_ref}' not found in subset '{subset_name}' "
                    f"(fates: {list(result.fate_names)})."
                )
            ref_idx = list(result.fate_names).index(shared_ref)
        cs_matrix = result.pairwise_nCS if normalized else result.pairwise_unCS
        for j, fate_name in enumerate(result.fate_names):
            if j == ref_idx:
                continue
            rows.append({
                "subset": subset_name,
                "fate": fate_name,
                "CS": float(cs_matrix[j, ref_idx]),
            })
    if not rows:
        raise ValueError("No fate pairs to plot: every subset has at most one fate.")

    df = pd.DataFrame(rows)
    fate_names = df["fate"].unique().tolist()
    colors = _fate_colors(fate_names)

    # Warn about subsets that have all-inf values (progenitor-only)
    bad_subsets = df.groupby("subset")["CS"].apply(lambda s: np.all(~np.isfinite(s)))
    progenitor_only = bad_subsets[bad_subsets].index.tolist()
    if progenitor_only:
        warnings.warn(
            f"Subsets with no fate-arm cells (pairwise_nCS = inf for all pairs): "
            f"{progenitor_only}. Rendered as hatched placeholders at y=0.",
            stacklevel=2,
        )

    # Global y-scale computed once, so every placeholder stub (and its label)
    # has the same height regardless of fate drawing order.  At least 1.0
    # keeps the CS = 1 reference line in view.
    all_cs = df["CS"].to_numpy(dtype=float)
    finite_cs = all_cs[np.isfinite(all_cs)]
    finite_max = max(1.0, float(finite_cs.max())) if finite_cs.size else 1.0
    stub_height = 0.04 * finite_max

    sns.set_theme(style="ticks")
    fig, ax = plt.subplots(figsize=figsize)

    subset_names = list(subset_results.keys())
    x = np.arange(len(subset_names))
    width = 0.8 / max(len(fate_names), 1)

    for j, fate in enumerate(fate_names):
        fate_cs = df.loc[df["fate"] == fate].set_index("subset")["CS"]
        vals, is_inf = [], []
        for s in subset_names:
            if s not in fate_cs.index:
                # Fate absent from this subset: empty slot, no placeholder.
                vals.append(0.0)
                is_inf.append(False)
            else:
                v = float(fate_cs.loc[s])
                finite = bool(np.isfinite(v))
                vals.append(v if finite else 0.0)
                is_inf.append(not finite)

        offset = (j - len(fate_names) / 2 + 0.5) * width
        bars = ax.bar(
            x + offset, vals, width * 0.9,
            label=fate, color=colors[fate], alpha=0.85,
            edgecolor=colors[fate], linewidth=0.5,
        )

        # Post-pass: hatched placeholder for non-finite bars
        for bar, inf_flag in zip(bars, is_inf):
            if not inf_flag:
                continue
            bar.set_facecolor("#DDDDDD")
            bar.set_edgecolor("#888888")
            bar.set_hatch("///")
            bar.set_alpha(0.6)
            bar.set_height(stub_height)  # small visible stub
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                stub_height + 0.015 * finite_max,
                "inf",
                ha="center", va="bottom",
                fontsize=7, color="#555555", fontstyle="italic",
            )

    ax.axhline(1.0, color="black", linestyle="--", linewidth=1, alpha=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(subset_names, rotation=15, ha="right")
    ax.set_ylabel("Commitment Score (CS)")
    ax.set_title(title)
    ax.legend(frameon=False)
    # Headroom for the "inf" labels
    ax.set_ylim(top=finite_max * 1.15)
    sns.despine(ax=ax)

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 9. ΔCS heatmap with CI annotation # ---------------------------------------------------------------------------
def plot_delta_cs_heatmap(
    delta_result: dict,
    title: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    cmap: str = "RdBu_r",
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Heatmap of ΔCS = nCS_A − nCS_B with CI annotation.

    Entry [i, j] = nCS_A(i÷j) − nCS_B(i÷j).  Positive values (red) mean
    condition A has stronger commitment of fate i relative to fate j.
    Cells are annotated with ``Δ ± CI_half`` where CI_half is half the
    width of [ci_low, ci_high]; only ``Δ`` is shown when the CI is not
    finite.  The colour scale is symmetric around 0.

    Parameters
    ----------
    delta_result : dict
        Output of PairScorer.compute_delta_CS().  Keys read here:
        ``delta_nCS``, ``ci_low``, ``ci_high`` (k × k arrays),
        ``fate_names``, ``condition_a``, ``condition_b``.
    title : str, optional
    figsize : tuple, optional
    cmap : str
        Diverging colormap.  Default: 'RdBu_r'.
    save_path : str, optional
        If given, the figure is also written to this path at 300 dpi.

    Returns
    -------
    fig : matplotlib Figure
    """
    import pandas as pd

    delta = np.asarray(delta_result["delta_nCS"], dtype=float)
    ci_low = np.asarray(delta_result["ci_low"], dtype=float)
    ci_high = np.asarray(delta_result["ci_high"], dtype=float)
    fate_names = list(delta_result["fate_names"])
    cond_a = delta_result["condition_a"]
    cond_b = delta_result["condition_b"]
    k = len(fate_names)

    if figsize is None:
        figsize = (max(4, k * 1.5), max(3.5, k * 1.3))

    # CI half-width for annotation
    ci_half = (ci_high - ci_low) / 2.0

    annot = np.empty((k, k), dtype=object)
    for i in range(k):
        for j in range(k):
            d = delta[i, j]
            h = ci_half[i, j]
            if np.isfinite(d) and np.isfinite(h):
                annot[i, j] = f"{d:+.2f}\n±{h:.2f}"
            elif np.isfinite(d):
                annot[i, j] = f"{d:+.2f}"
            else:
                annot[i, j] = "—"

    df = pd.DataFrame(delta, index=fate_names, columns=fate_names)
    finite_vals = delta[np.isfinite(delta)]
    vmax = float(np.abs(finite_vals).max()) if finite_vals.size else 0.0
    if vmax <= 0:
        # All-zero or all-missing deltas: avoid a degenerate [0, 0] scale.
        vmax = 1.0

    sns.set_theme(style="ticks")
    fig, ax = plt.subplots(figsize=figsize)

    sns.heatmap(
        df,
        ax=ax,
        cmap=cmap,
        center=0,
        vmin=-vmax,
        vmax=vmax,
        annot=annot,
        fmt="",
        linewidths=0.5,
        cbar_kws={"label": "ΔnCS", "shrink": 0.8},
        annot_kws={"size": 9},
    )
    ax.set_title(
        title or f"ΔCS: '{cond_a}' − '{cond_b}'",
        fontsize=11,
    )
    ax.set_xlabel("Reference fate (÷)")
    ax.set_ylabel("Query fate (×)")

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 10. Grouped bar chart of nCS per condition # ---------------------------------------------------------------------------
def plot_compare_conditions_bar(
    results: Dict[str, "CommitmentScoreResult"],
    ref_fate: Optional[str] = None,
    color_map: Optional[Dict[str, str]] = None,
    title: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Grouped bar chart of nCS per condition.

    For each fate pair (query ÷ reference), one group of bars — one bar per
    condition, colored by CONDITION_PALETTE (or ``color_map``).  A
    horizontal dashed line at CS = 1 marks the neutral point.  Non-finite
    nCS values are not drawn as bars; their slot is labelled ``"inf"`` (or
    ``"n/a"`` for NaN).

    Fates are matched across conditions by name, using the first
    condition's fate order, so results that store their fates in a
    different order are still read correctly.

    Parameters
    ----------
    results : dict
        Mapping of condition_label -> CommitmentScoreResult (output of
        PairScorer.score_all_conditions()).
    ref_fate : str, optional
        Reference fate for the denominator.  If None, uses the fate with the
        lowest mean M_sector across conditions.
    color_map : dict, optional
        condition_label -> hex color.  Falls back to CONDITION_PALETTE.
    title : str, optional
    figsize : tuple, optional
    save_path : str, optional
        If given, the figure is also written to this path at 300 dpi.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``results`` is empty, fewer than two fates are present, a
        condition lacks one of the first condition's fates, or ``ref_fate``
        is unknown.
    """
    if not results:
        raise ValueError("results is empty.")
    conditions = list(results.keys())
    fate_names = list(next(iter(results.values())).fate_names)
    if len(fate_names) < 2:
        raise ValueError(
            f"Need at least 2 fates for a pairwise comparison, got {fate_names}."
        )

    # Per-condition name -> index maps (fate order may differ between results).
    positions = {}
    for cond in conditions:
        names = list(results[cond].fate_names)
        missing = [f for f in fate_names if f not in names]
        if missing:
            raise ValueError(
                f"Condition '{cond}' is missing fates {missing} "
                f"present in '{conditions[0]}'."
            )
        positions[cond] = {f: names.index(f) for f in fate_names}

    # Resolve reference fate
    if ref_fate is None:
        mean_m = [
            np.mean([results[c].M_sector[positions[c][f]] for c in conditions])
            for f in fate_names
        ]
        ref_fate = fate_names[int(np.argmin(mean_m))]
    elif ref_fate not in fate_names:
        raise ValueError(f"ref_fate '{ref_fate}' not in fate_names: {fate_names}")

    query_names = [f for f in fate_names if f != ref_fate]
    fate_pairs = [f"{q} ÷ {ref_fate}" for q in query_names]

    # values[c, p] = nCS of query fate p relative to ref_fate in condition c.
    values = np.array(
        [
            [
                results[c].pairwise_nCS[positions[c][q], positions[c][ref_fate]]
                for q in query_names
            ]
            for c in conditions
        ],
        dtype=float,
    )
    n_pairs = len(fate_pairs)
    n_conds = len(conditions)

    cond_colors = _condition_colors(conditions, color_map)

    # Label offset proportional to the visible y-range.
    finite = values[np.isfinite(values)]
    y_top = max(1.0, float(finite.max())) if finite.size else 1.0
    label_pad = 0.02 * y_top

    if figsize is None:
        figsize = (max(5, n_pairs * n_conds * 0.8), 4.5)

    sns.set_theme(style="ticks")
    fig, ax = plt.subplots(figsize=figsize)

    x = np.arange(n_pairs)
    width = 0.8 / n_conds

    for ci, cond in enumerate(conditions):
        vals = values[ci]
        # Infinite heights would break autoscaling; draw nothing instead.
        heights = np.where(np.isfinite(vals), vals, np.nan)
        offset = (ci - n_conds / 2 + 0.5) * width
        bars = ax.bar(
            x + offset, heights, width * 0.9,
            color=cond_colors[cond], alpha=0.85,
            label=cond, edgecolor="white", linewidth=0.5,
        )
        # Value labels
        for bar, v in zip(bars, vals):
            center = bar.get_x() + bar.get_width() / 2
            if np.isfinite(v):
                ax.text(
                    center, v + label_pad, f"{v:.2f}",
                    ha="center", va="bottom", fontsize=7.5,
                )
            else:
                ax.text(
                    center, label_pad, "inf" if np.isinf(v) else "n/a",
                    ha="center", va="bottom", fontsize=7,
                    color="#555555", fontstyle="italic",
                )

    ax.axhline(1.0, color="black", linestyle="--", linewidth=1.0,
               alpha=0.45, label="CS = 1")
    ax.set_xticks(x)
    ax.set_xticklabels(fate_pairs, rotation=15, ha="right")
    ax.set_ylabel("nCS")
    ax.set_title(title or f"nCS by condition (÷ '{ref_fate}')", fontsize=11)
    ax.legend(frameon=False, fontsize=9)
    sns.despine(ax=ax)

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 11. Radar / spider chart of commitment vectors # ---------------------------------------------------------------------------
def plot_commitment_vector_radar(
    results: Dict[str, "CommitmentScoreResult"],
    color_map: Optional[Dict[str, str]] = None,
    title: Optional[str] = None,
    figsize: Tuple[float, float] = (6, 6),
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Radar / spider chart of commitment vectors per condition.

    Each condition is one closed polygon.  Axes = fate names (k spokes, in
    the first condition's fate order; other conditions are aligned by fate
    name).  Values = ``commitment_vector`` (documented to sum to 1), plotted
    on a fixed [0, 1] radial axis.  Conditions colored by CONDITION_PALETTE
    (or ``color_map``).

    For k < 3, falls back to a grouped bar chart with a warning.

    Parameters
    ----------
    results : dict
        Mapping of condition_label -> CommitmentScoreResult (output of
        PairScorer.score_all_conditions()).
    color_map : dict, optional
        condition_label -> hex color.  Falls back to CONDITION_PALETTE.
    title : str, optional
    figsize : tuple
    save_path : str, optional
        If given, the figure is also written to this path at 300 dpi.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``results`` is empty or a condition lacks one of the first
        condition's fates.
    """
    if not results:
        raise ValueError("results is empty.")
    conditions = list(results.keys())
    fate_names = list(next(iter(results.values())).fate_names)
    k = len(fate_names)

    cond_colors = _condition_colors(conditions, color_map)

    # Align every condition's commitment vector to the common fate order.
    vectors = {}
    for cond in conditions:
        res = results[cond]
        names = list(res.fate_names)
        missing = [f for f in fate_names if f not in names]
        if missing:
            raise ValueError(
                f"Condition '{cond}' is missing fates {missing} "
                f"present in '{conditions[0]}'."
            )
        cv = np.asarray(res.commitment_vector, dtype=float)
        vectors[cond] = cv[[names.index(f) for f in fate_names]]

    default_title = "Commitment vectors by condition"

    if k < 3:
        warnings.warn(
            f"k={k} fates — radar chart requires k≥3. Falling back to bar chart.",
            stacklevel=2,
        )
        # Simple bar chart fallback
        sns.set_theme(style="ticks")
        fig, ax = plt.subplots(figsize=figsize)
        x = np.arange(k)
        width = 0.8 / len(conditions)
        for ci, cond in enumerate(conditions):
            offset = (ci - len(conditions) / 2 + 0.5) * width
            ax.bar(x + offset, vectors[cond], width * 0.9,
                   color=cond_colors[cond], alpha=0.85, label=cond)
        ax.set_xticks(x)
        ax.set_xticklabels(fate_names)
        ax.set_ylabel("Commitment weight")
        ax.set_title(title or default_title, fontsize=11)
        ax.legend(frameon=False)
        sns.despine(ax=ax)
        plt.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
        return fig

    # Radar chart: one spoke per fate, repeat the first point to close shapes.
    angles = np.linspace(0, 2 * np.pi, k, endpoint=False).tolist()
    angles += angles[:1]

    sns.set_theme(style="ticks")
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="polar")

    for cond in conditions:
        cv = vectors[cond].tolist()
        cv += cv[:1]  # close
        color = cond_colors[cond]
        ax.plot(angles, cv, color=color, linewidth=2.0, label=cond)
        ax.fill(angles, cv, color=color, alpha=0.15)

    # Spoke labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(fate_names, fontsize=10)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.25, 0.5, 0.75, 1.0])
    ax.set_yticklabels(["0.25", "0.50", "0.75", "1.00"], fontsize=7, color="gray")
    ax.grid(True, alpha=0.3)
    # First fate at the top, remaining fates clockwise.
    ax.set_theta_zero_location("N")
    ax.set_theta_direction(-1)

    ax.legend(
        loc="upper right",
        bbox_to_anchor=(1.35, 1.1),
        fontsize=9,
        frameon=False,
    )
    ax.set_title(title or default_title, pad=20, fontsize=12)

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 12. Omnibus summary heatmap (MultiScorer) # ---------------------------------------------------------------------------
def plot_omnibus_summary(
    omnibus_df,
    results: Dict[str, CommitmentScoreResult],
    posthoc_df=None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Summary heatmap: fates × conditions showing omnibus significance.

    Left panel: heatmap of mean per-cell affinity per fate per condition,
    annotated with omnibus adjusted-p-value stars.  When a condition has no
    per-cell scores, its column falls back to the M_sector proportions
    (M_sector / sum(M_sector)); it stays 0 if that sum is not positive.

    Right panel (if posthoc_df provided): number of significant pairwise
    comparisons per fate, as a horizontal bar chart.

    Parameters
    ----------
    omnibus_df : pd.DataFrame
        Output of MultiScorer.compare_omnibus().
        Columns: fate, test, statistic, pval, pval_adj, significant.
    results : dict
        Mapping of condition_label -> CommitmentScoreResult (output of
        MultiScorer.score_all_conditions()).  Fates are matched by name to
        the first condition's fate order.
    posthoc_df : pd.DataFrame, optional
        Output of MultiScorer.compare_posthoc().  If provided and non-empty,
        the right panel is drawn.
    figsize : tuple, optional
    save_path : str, optional
        If given, the figure is also written to this path at 300 dpi.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``results`` is empty or a condition lacks one of the first
        condition's fates.
    """
    import pandas as pd

    if not results:
        raise ValueError("results is empty.")
    conditions = list(results.keys())
    fate_names = list(next(iter(results.values())).fate_names)
    n_conds = len(conditions)
    n_fates = len(fate_names)

    # Build mean affinity matrix: fates × conditions
    mean_affinity = np.zeros((n_fates, n_conds))
    for ci, cond in enumerate(conditions):
        res = results[cond]
        names = list(res.fate_names)
        missing = [f for f in fate_names if f not in names]
        if missing:
            raise ValueError(
                f"Condition '{cond}' is missing fates {missing} "
                f"present in '{conditions[0]}'."
            )
        order = [names.index(f) for f in fate_names]
        if res.cell_scores is not None:
            mean_affinity[:, ci] = np.asarray(res.cell_scores)[:, order].mean(axis=0)
        else:
            # Fall back to M_sector proportions for every fate (the whole
            # column, not a single entry).
            sector = np.asarray(res.M_sector, dtype=float)
            total_m = sector.sum()
            if total_m > 0:
                mean_affinity[:, ci] = sector[order] / total_m

    # Significance stars per fate from the omnibus table
    omnibus_map = {}
    if omnibus_df is not None and len(omnibus_df) > 0:
        omnibus_map = {row["fate"]: row for _, row in omnibus_df.iterrows()}

    annot = np.empty((n_fates, n_conds), dtype=object)
    for fi, fate in enumerate(fate_names):
        stars = (
            _significance_stars(omnibus_map[fate]["pval_adj"])
            if fate in omnibus_map else None
        )
        for ci in range(n_conds):
            val = mean_affinity[fi, ci]
            annot[fi, ci] = f"{val:.3f}" if stars is None else f"{val:.3f}\n{stars}"

    # Determine layout
    has_posthoc = posthoc_df is not None and len(posthoc_df) > 0
    n_panels = 2 if has_posthoc else 1
    if figsize is None:
        panel_w = max(4, n_conds * 1.5)
        panel_h = max(3, n_fates * 0.8)
        figsize = (panel_w * n_panels + 1, panel_h)

    sns.set_theme(style="ticks")
    if has_posthoc:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize,
                                        gridspec_kw={"width_ratios": [2, 1]})
    else:
        fig, ax1 = plt.subplots(1, 1, figsize=figsize)
        ax2 = None

    # --- Left panel: mean affinity heatmap ---
    aff_df = pd.DataFrame(mean_affinity, index=fate_names, columns=conditions)
    sns.heatmap(
        aff_df,
        ax=ax1,
        cmap="YlOrRd",
        annot=annot,
        fmt="",
        linewidths=0.5,
        cbar_kws={"label": "Mean affinity", "shrink": 0.8},
        annot_kws={"size": 9},
    )
    ax1.set_title("Mean affinity + omnibus significance", fontsize=11)
    ax1.set_xlabel("Condition")
    ax1.set_ylabel("Fate")

    # --- Right panel: count of significant post-hoc comparisons per fate ---
    if ax2 is not None:
        has_sig_col = "significant" in posthoc_df.columns
        sig_vals = [
            int(posthoc_df.loc[posthoc_df["fate"] == f, "significant"].sum())
            if has_sig_col else 0
            for f in fate_names
        ]
        fate_colors = _fate_colors(fate_names)
        bar_colors = [fate_colors[f] for f in fate_names]
        ax2.barh(fate_names, sig_vals, color=bar_colors, alpha=0.85)
        ax2.set_xlabel("Significant pairwise comparisons")
        ax2.set_title("Post-hoc significance count", fontsize=11)
        ax2.invert_yaxis()  # match the heatmap's top-to-bottom fate order
        sns.despine(ax=ax2)

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 13. Post-hoc heatmap (MultiScorer) # ---------------------------------------------------------------------------
def plot_posthoc_heatmap(
    posthoc_df,
    fate: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """Condition × condition heatmap of post-hoc p-values for a given fate.

    Lower triangle: adjusted p-values, coloured by -log10(p) and annotated
    with the value and significance stars.  Upper triangle: ``delta_mean``
    for the comparison "row vs column" (sign flipped when the comparison was
    listed as "column vs row"), with the same stars, drawn as text on an
    uncoloured background.  The diagonal is left empty.

    Parameters
    ----------
    posthoc_df : pd.DataFrame
        Output of MultiScorer.compare_posthoc().
        Columns: fate, comparison, method, statistic, pval, pval_adj,
        significant, mean_A, mean_B, delta_mean.  ``comparison`` strings
        must have the form ``"<A> vs <B>"``; others are ignored.
    fate : str, optional
        Which fate to plot.  If None, uses the first fate in posthoc_df.
    figsize : tuple, optional
    save_path : str, optional
        If given, the figure is also written to this path at 300 dpi.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``posthoc_df`` is empty, has no rows for ``fate``, or yields
        fewer than two conditions.
    """
    import pandas as pd

    if posthoc_df is None or len(posthoc_df) == 0:
        raise ValueError("posthoc_df is empty or None.")

    # Resolve fate
    if fate is None:
        fate = posthoc_df["fate"].iloc[0]
    fate_df = posthoc_df[posthoc_df["fate"] == fate]
    if len(fate_df) == 0:
        raise ValueError(f"No post-hoc results for fate '{fate}'.")

    # Parse "<A> vs <B>" comparisons once; malformed strings are skipped.
    parsed = []
    for _, row in fate_df.iterrows():
        parts = row["comparison"].split(" vs ")
        if len(parts) == 2:
            parsed.append((parts[0].strip(), parts[1].strip(), row))

    conditions = sorted({c for a, b, _ in parsed for c in (a, b)})
    n = len(conditions)
    if n < 2:
        raise ValueError(f"Need at least 2 conditions for post-hoc heatmap, got {n}.")
    cond_idx = {c: i for i, c in enumerate(conditions)}

    # Symmetric p-value matrix, antisymmetric delta matrix
    pval_matrix = np.full((n, n), np.nan)
    delta_matrix = np.full((n, n), np.nan)
    for a, b, row in parsed:
        ai, bi = cond_idx[a], cond_idx[b]
        pval_matrix[ai, bi] = row["pval_adj"]
        pval_matrix[bi, ai] = row["pval_adj"]
        d = row.get("delta_mean", np.nan)
        d = np.nan if pd.isna(d) else float(d)
        delta_matrix[ai, bi] = d
        delta_matrix[bi, ai] = -d

    # Annotation: lower = p-value + stars, upper = delta + stars
    annot = np.full((n, n), "", dtype=object)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            p = pval_matrix[i, j]
            stars = _significance_stars(p) if not np.isnan(p) else ""
            if i > j:
                if not np.isnan(p):
                    annot[i, j] = f"{p:.3f}\n{stars}"
            else:
                d = delta_matrix[i, j]
                if not np.isnan(d):
                    annot[i, j] = f"{d:+.3f}\n{stars}"

    if figsize is None:
        figsize = (max(5, n * 1.5), max(4, n * 1.3))

    sns.set_theme(style="ticks")
    fig, ax = plt.subplots(figsize=figsize)

    # Colour only the lower triangle (diagonal and upper set to NaN).
    pval_display = np.copy(pval_matrix)
    pval_display[np.triu_indices(n)] = np.nan
    log_pval = -np.log10(np.clip(pval_display, 1e-10, None))
    log_df = pd.DataFrame(log_pval, index=conditions, columns=conditions)

    sns.heatmap(
        log_df,
        ax=ax,
        cmap="Reds",
        annot=annot,
        fmt="",
        linewidths=0.5,
        cbar_kws={"label": "-log10(adj p-value)", "shrink": 0.8},
        annot_kws={"size": 9},
    )
    # seaborn masks NaN cells and skips their annotations, so the
    # upper-triangle delta labels must be drawn explicitly.  Heatmap cell
    # (i, j) is centred at data coordinates (j + 0.5, i + 0.5).
    for i, j in zip(*np.triu_indices(n, k=1)):
        if annot[i, j]:
            ax.text(j + 0.5, i + 0.5, annot[i, j],
                    ha="center", va="center", fontsize=9, color="black")

    ax.set_title(f"Post-hoc comparisons — {fate}", fontsize=11)
    ax.set_xlabel("Condition")
    ax.set_ylabel("Condition")

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig
# --------------------------------------------------------------------------- # 14. Pairwise delta grid (MultiScorer) # ---------------------------------------------------------------------------
def plot_pairwise_delta_grid(
    delta_results: Dict[Tuple[str, str], dict],
    figsize_per_panel: Tuple[float, float] = (4, 4),
    save_path: Optional[str] = None,
    shared_scale: bool = True,
) -> plt.Figure:
    """Grid of ΔCS heatmaps for all condition pairs.

    Each panel shows the ΔnCS heatmap for one condition pair, annotated
    like plot_delta_cs_heatmap() (``Δ ± CI_half``).  Panels are laid out in
    up to three columns; unused grid slots are hidden.  No colorbar is
    drawn, so by default all panels share one symmetric colour scale,
    which keeps colours comparable across panels.

    Parameters
    ----------
    delta_results : dict
        Output of MultiScorer.compute_pairwise_deltas().  Mapping of
        (cond_a, cond_b) -> delta_result dict with keys ``delta_nCS``,
        ``ci_low``, ``ci_high``, ``fate_names``, ``condition_a``,
        ``condition_b``.
    figsize_per_panel : tuple
        Size of each subplot.
    save_path : str, optional
        If given, the figure is also written to this path at 300 dpi.
    shared_scale : bool
        If True (default), every panel uses the largest |ΔnCS| across all
        panels as its colour limit.  If False, each panel is scaled
        independently.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If ``delta_results`` is empty.
    """
    import pandas as pd

    n_pairs = len(delta_results)
    if n_pairs == 0:
        raise ValueError("delta_results is empty.")

    def _abs_max(values: np.ndarray) -> float:
        # Largest finite |value|, or 0.0 when nothing is finite.
        finite = values[np.isfinite(values)]
        return float(np.abs(finite).max()) if finite.size else 0.0

    panels = list(delta_results.values())
    deltas = [np.asarray(p["delta_nCS"], dtype=float) for p in panels]
    # `or 1.0` avoids a degenerate [0, 0] colour range.
    common_vmax = max(_abs_max(d) for d in deltas) or 1.0

    ncols = min(n_pairs, 3)
    nrows = int(np.ceil(n_pairs / ncols))

    sns.set_theme(style="ticks")
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(figsize_per_panel[0] * ncols, figsize_per_panel[1] * nrows),
        squeeze=False,
    )

    for idx, (delta_result, delta) in enumerate(zip(panels, deltas)):
        row, col = divmod(idx, ncols)
        ax = axes[row][col]

        ci_low = np.asarray(delta_result["ci_low"], dtype=float)
        ci_high = np.asarray(delta_result["ci_high"], dtype=float)
        fate_names = list(delta_result["fate_names"])
        cond_a = delta_result["condition_a"]
        cond_b = delta_result["condition_b"]
        k = len(fate_names)

        ci_half = (ci_high - ci_low) / 2.0

        annot = np.empty((k, k), dtype=object)
        for i in range(k):
            for j in range(k):
                d = delta[i, j]
                h = ci_half[i, j]
                if np.isfinite(d) and np.isfinite(h):
                    annot[i, j] = f"{d:+.2f}\n±{h:.2f}"
                elif np.isfinite(d):
                    annot[i, j] = f"{d:+.2f}"
                else:
                    annot[i, j] = "—"

        vmax = common_vmax if shared_scale else (_abs_max(delta) or 1.0)
        df = pd.DataFrame(delta, index=fate_names, columns=fate_names)

        sns.heatmap(
            df,
            ax=ax,
            cmap="RdBu_r",
            center=0,
            vmin=-vmax,
            vmax=vmax,
            annot=annot,
            fmt="",
            linewidths=0.5,
            cbar=False,
            annot_kws={"size": 7},
        )
        ax.set_title(f"ΔCS: {cond_a} − {cond_b}", fontsize=9)
        ax.set_xlabel("Ref fate", fontsize=8)
        ax.set_ylabel("Query fate", fontsize=8)
        ax.tick_params(labelsize=7)

    # Hide unused axes
    for idx in range(n_pairs, nrows * ncols):
        row, col = divmod(idx, ncols)
        axes[row][col].set_visible(False)

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    return fig