Source code for scCS.embedding

"""
embedding.py — Radial star embedding for scCS.

Constructs a custom 2D layout where:
  - The bifurcation cluster (progenitor) sits at the origin (0, 0).
  - Each terminal fate population occupies its own radial arm, evenly
    spaced at 360/k degrees around the origin.
  - Within each arm, cells are ordered along the radial axis by a
    differentiation metric (pseudotime, CytoTRACE2, pathway score, etc.)
    so that less-differentiated cells are close to the center and
    more-differentiated cells are at the periphery.
  - ONLY cells belonging to the bifurcation cluster or a terminal fate
    are included.  All other populations are excluded from the embedding.

The result is stored in adata_sub.obsm['X_sccs'] on the returned subset
AnnData, and looks like a star or sunburst when plotted — one arm per
fate, radiating from the progenitor.

Velocity projection
-------------------
RNA velocity vectors (from scVelo) are projected into this custom 2D
space by computing the transition-probability-weighted displacement of
each cell in the scCS coordinate system.

Differentiation metrics supported
----------------------------------
- 'pseudotime'   : scVelo velocity_pseudotime (default)
- 'cytotrace'    : CytoTRACE2 score (column in adata.obs)
- 'custom'       : any per-cell numeric column in adata.obs
- np.ndarray     : directly supplied per-cell scores (shape n_cells,)

In all cases, higher score = more differentiated = farther from center.
If the metric is inverted (e.g., CytoTRACE2 where high = less
differentiated), pass invert_ordering=True.
"""

from __future__ import annotations
import anndata
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np

try:
    import scvelo as scv
    _SCVELO_AVAILABLE = True
except ImportError:
    _SCVELO_AVAILABLE = False

try:
    import scanpy as sc
    _SCANPY_AVAILABLE = True
except ImportError:
    _SCANPY_AVAILABLE = False


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

[docs] def build_star_embedding( adata, root: str, branches: List[str], obs_key: str = "leiden", ordering_metric: Union[str, np.ndarray] = "pseudotime", invert_ordering: bool = False, arm_scale: float = 10.0, jitter: float = 0.3, seed: int = 42, arm_norm: str = "global", ) -> "anndata.AnnData": """Build the radial star embedding on a subset of adata. Only cells belonging to the bifurcation cluster or a terminal fate cluster are included. All other populations are excluded entirely. Parameters ---------- adata : AnnData Full dataset. Will NOT be modified. root : str Label of the progenitor/bifurcation cluster in adata.obs[obs_key]. These cells are placed at the origin. branches : list of str Labels of the k terminal fate populations. Each gets one radial arm. obs_key : str Column in adata.obs with cluster labels. ordering_metric : str or np.ndarray How to order cells along each arm: - 'pseudotime' : uses adata.obs['velocity_pseudotime'] (computed if absent) - 'cytotrace' : uses adata.obs['cytotrace2_score'] (must be pre-computed) - any str : uses adata.obs[ordering_metric] directly - np.ndarray : per-cell scores, shape (n_cells,) for the FULL adata Higher value = more differentiated = farther from center. invert_ordering : bool If True, invert the metric so that high values map to the center (use for metrics where high = less differentiated, e.g. raw CytoTRACE2). arm_scale : float Maximum radial distance (length of each arm). jitter : float Gaussian noise added perpendicular to each arm to avoid overplotting. seed : int Random seed for jitter. arm_norm : {"global", "per_arm"}, default "global" How to normalize the ordering metric onto the radial arms. The rescale formula ``(s - s_min) / (s_max - s_min) * arm_scale`` is only applied to fate cells (bifurcation cells sit at the origin); ``s_min`` and ``s_max`` are computed from fate cells only in both modes since v0.7.4, so the closest fate cell always maps to ``r ≈ 0`` and the furthest to ``r ≈ arm_scale``. - ``"global"`` (default, v0.7.3+): compute one ``(s_min, s_max) = (fate_scores.min(), fate_scores.max())`` over all fate cells and apply uniformly to every arm. Arms whose cells span shorter pseudotime intervals stay visibly shorter. Preserves the *relative* ordering of cells across arms — if Alpha cells span a wider pseudotime range than Delta cells, the Alpha arm extends further. Biologically meaningful: arm length reflects how far each fate has differentiated from the progenitor on a shared scale. - ``"per_arm"`` (legacy, pre-v0.7.3 default): each arm gets its own ``(s_min, s_max) = (fate_mask_scores.min(), fate_mask_scores.max())`` and is mapped to ``[0, arm_scale]`` independently. All arms reach the full ``arm_scale`` regardless of how compressed/extended their pseudotime range is. Provided for reproducibility of older plots. .. versionchanged:: 0.7.4 Both modes now compute ``s_min``/``s_max`` from fate cells only, instead of including bifurcation cells. This removes a visible gap between the origin and the start of each arm in v0.7.3 ``"global"`` mode. Returns ------- adata_sub : AnnData Subset containing ONLY bifurcation + terminal fate cells. Star embedding stored in adata_sub.obsm['X_sccs']. Metadata stored in adata_sub.uns['sccs']. """ import anndata rng = np.random.default_rng(seed) obs_labels_full = adata.obs[obs_key].astype(str).values # --- 0. Subset to relevant cells only --- keep_labels = set([str(root)] + [str(f) for f in branches]) keep_mask = np.array([l in keep_labels for l in obs_labels_full]) if keep_mask.sum() == 0: raise ValueError( f"No cells found matching root='{root}' " f"or branches={branches} in " f"adata.obs['{obs_key}']." ) # --- Resolve differentiation metric on the FULL adata BEFORE subsetting --- # This is critical for 'pseudotime': scVelo's velocity_pseudotime computation # requires the intact neighbor/velocity graph, which breaks after subsetting. # We resolve the metric on the full object, then slice to keep_mask. metric_for_sub: np.ndarray # will always be a pre-resolved array after this block if isinstance(ordering_metric, np.ndarray): arr = np.asarray(ordering_metric, dtype=float).ravel() if len(arr) != adata.n_obs: raise ValueError( f"Custom metric array has length {len(arr)}, " f"expected {adata.n_obs} (full adata)." ) metric_for_sub = arr[keep_mask] else: # Resolve on full adata (graph intact), then slice scores_full = _resolve_metric(adata, ordering_metric, invert_ordering) metric_for_sub = scores_full[keep_mask] adata_sub = adata[keep_mask].copy() obs_labels = adata_sub.obs[obs_key].astype(str).values n_cells = adata_sub.n_obs print(f"[scCS] Subsetting: {keep_mask.sum()} / {adata.n_obs} cells kept") print(f" ({adata.n_obs - keep_mask.sum()} cells from other populations excluded)") for lbl in sorted(keep_labels): n = (obs_labels == lbl).sum() role = "progenitor" if lbl == str(root) else "fate" print(f" {lbl}: {n} cells ({role})") # --- 1. Use the pre-resolved metric (already sliced to subset) --- # metric_for_sub is always a np.ndarray at this point (resolved above). # _fill_nan handles any remaining NaNs; inversion was already applied. scores = _fill_nan(np.asarray(metric_for_sub, dtype=float).ravel()) # --- 2. Compute arm directions (evenly spaced angles) --- k = len(branches) arm_angles_deg = np.linspace(0.0, 360.0, k, endpoint=False) arm_angles_rad = np.radians(arm_angles_deg) arm_dirs = np.stack([np.cos(arm_angles_rad), np.sin(arm_angles_rad)], axis=1) # (k, 2) # --- 3. Assign each cell to an arm --- # Bifurcation cells -> arm index -1 (origin) # Terminal fate cells -> their arm index arm_assignment = np.full(n_cells, -1, dtype=int) for j, fate in enumerate(branches): mask = obs_labels == str(fate) arm_assignment[mask] = j # --- 4. Compute arm score ranges for normalization --- # The rescale formula (s - s_min) / (s_max - s_min) * arm_scale is # only applied to FATE cells (bifurcation cells sit at origin via L162). # Computing s_min/s_max over fate cells only ensures the closest fate # cell maps to r=0 and the furthest to r=arm_scale, so each arm visibly # starts at the origin instead of mid-arm. bif_mask_sub = obs_labels == str(root) fate_mask_all = np.zeros(len(scores), dtype=bool) for fate in branches: fate_mask_all |= (obs_labels == str(fate)) if arm_norm == "global": # One range across all FATE cells; every arm uses it. Preserves the # relative pseudotime ordering of fate cells across arms — arms whose # cells span shorter pseudotime intervals stay visibly shorter. if fate_mask_all.sum() > 0: fate_scores = scores[fate_mask_all] g_min = float(fate_scores.min()) g_max = float(fate_scores.max()) else: g_min, g_max = float(scores.min()), float(scores.max()) arm_score_ranges = [(g_min, g_max)] * k elif arm_norm == "per_arm": # Each arm rescales by its OWN fate cells' (s_min, s_max). Every # arm visibly reaches the full radial cap regardless of underlying # pseudotime extent (legacy behavior; ignores cross-arm comparison). arm_score_ranges = [] for j, fate in enumerate(branches): fate_mask = obs_labels == str(fate) if fate_mask.sum() > 0: s = scores[fate_mask] arm_score_ranges.append((float(s.min()), float(s.max()))) else: arm_score_ranges.append((float(scores.min()), float(scores.max()))) else: raise ValueError( f"arm_norm must be 'global' or 'per_arm', got {arm_norm!r}." ) # --- 5. Place cells in 2D --- coords = np.zeros((n_cells, 2), dtype=float) # Bifurcation cluster: cluster at origin with small jitter n_bif = bif_mask_sub.sum() if n_bif > 0: coords[bif_mask_sub] = rng.normal(0.0, jitter * 0.5, size=(n_bif, 2)) # Fate cells: place along their assigned arm for j in range(k): cell_mask = arm_assignment == j if cell_mask.sum() == 0: continue s_min, s_max = arm_score_ranges[j] if s_max <= s_min: r = np.linspace(0.0, arm_scale, cell_mask.sum()) else: cell_scores_arm = scores[cell_mask] r = (cell_scores_arm - s_min) / (s_max - s_min) * arm_scale r = np.clip(r, 0.0, arm_scale) arm_dir = arm_dirs[j] positions = np.outer(r, arm_dir) perp_dir = np.array([-arm_dir[1], arm_dir[0]]) perp_noise = rng.normal(0.0, jitter, size=cell_mask.sum()) positions += np.outer(perp_noise, perp_dir) coords[cell_mask] = positions # --- 6. Store in subset adata --- adata_sub.obsm["X_sccs"] = coords if "sccs" not in adata_sub.uns: adata_sub.uns["sccs"] = {} adata_sub.uns["sccs"]["arm_angles_deg"] = arm_angles_deg adata_sub.uns["sccs"]["arm_dirs"] = arm_dirs adata_sub.uns["sccs"]["arm_scale"] = arm_scale adata_sub.uns["sccs"]["arm_norm"] = arm_norm adata_sub.uns["sccs"]["fate_names"] = [str(f) for f in branches] adata_sub.uns["sccs"]["root"] = str(root) adata_sub.uns["sccs"]["obs_key"] = obs_key # Store integer indices of kept cells in the original adata (for velocity projection) adata_sub.uns["sccs"]["parent_indices"] = np.where(keep_mask)[0] adata_sub.obs["sccs_arm"] = arm_assignment adata_sub.obs["sccs_branch"] = [ str(branches[a]) if a >= 0 else str(root) for a in arm_assignment ] print( f'\n[scCS] Star embedding built → adata_sub.obsm["X_sccs"] shape: {coords.shape}' ) print( f' Arm angles: ' + str({str(f): round(float(a), 1) for f, a in zip(branches, arm_angles_deg)}) ) return adata_sub
[docs] def project_velocity_star( adata_sub, adata_full=None, verbose: bool = True, ) -> Tuple[np.ndarray, np.ndarray]: """Project RNA velocity into the scCS star embedding space. Uses the transition probability matrix from the full (unsubsetted) adata to compute the expected displacement of each subset cell in the X_sccs coordinate system. This is necessary because subsetting breaks the velocity/neighbor graph matrices (they retain full-dataset dimensions). We always use the full graph and restrict to subset cell indices. Parameters ---------- adata_sub : AnnData Subset returned by build_star_embedding(). Must have X_sccs in obsm and a 'sccs_parent_indices' entry in uns (set automatically). adata_full : AnnData, optional The original full dataset with intact velocity_graph in uns. If None, falls back to using adata_sub directly (only works if velocity_graph was computed on the subset). Returns ------- vx, vy : np.ndarray, shape (n_sub_cells,) Velocity components in the scCS embedding. Also stored in adata_sub.obsm['velocity_sccs']. """ if "X_sccs" not in adata_sub.obsm: raise ValueError( "X_sccs embedding not found. Run build_star_embedding() first." ) coords_sub = np.array(adata_sub.obsm["X_sccs"]) # (n_sub, 2) n_sub = adata_sub.n_obs # Retrieve the parent indices (positions in full adata) stored during subsetting parent_idx = adata_sub.uns.get("sccs", {}).get("parent_indices", None) # ── Strategy 1: scVelo velocity_embedding on the full adata ────────────── # Run on full adata, then slice to subset rows. This is the most accurate. if _SCVELO_AVAILABLE and adata_full is not None and "velocity_graph" in adata_full.uns: if verbose: print("[scCS] Projecting velocity via scVelo on full adata → slicing to subset...") # Temporarily inject X_sccs into full adata for all cells. # Subset cells get their star coords; other cells get zeros (ignored after slicing). n_full = adata_full.n_obs coords_full = np.zeros((n_full, 2), dtype=float) if parent_idx is not None: coords_full[parent_idx] = coords_sub else: # Fallback: match by obs_names sub_names = set(adata_sub.obs_names) full_names = list(adata_full.obs_names) idx_map = [i for i, n in enumerate(full_names) if n in sub_names] coords_full[idx_map] = coords_sub adata_full.obsm["X_sccs_tmp"] = coords_full try: scv.tl.velocity_embedding(adata_full, basis="sccs_tmp") V_full = np.array(adata_full.obsm["velocity_sccs_tmp"]) # (n_full, 2) # Slice to subset if parent_idx is not None: V_sub = V_full[parent_idx] else: V_sub = V_full[idx_map] vx, vy = V_sub[:, 0], V_sub[:, 1] adata_sub.obsm["velocity_sccs"] = V_sub if verbose: print(f"[scCS] Velocity projected. Shape: {V_sub.shape}") return vx, vy except Exception as e: warnings.warn( f"scVelo velocity_embedding on full adata failed ({e}). " "Falling back to graph-based projection.", RuntimeWarning, stacklevel=2, ) finally: # Always clean up temp keys, even if velocity_embedding raised adata_full.obsm.pop("X_sccs_tmp", None) adata_full.obsm.pop("velocity_sccs_tmp", None) # ── Strategy 2: graph-based projection using full adata's velocity_graph ── # Manually compute T[sub, :][:, sub] × coords_sub - coords_sub if adata_full is not None and "velocity_graph" in adata_full.uns: if verbose: print("[scCS] Using graph-based projection from full velocity_graph...") try: import scipy.sparse as sp T_full = adata_full.uns["velocity_graph"] if not sp.issparse(T_full): T_full = sp.csr_matrix(T_full) if parent_idx is None: sub_names = set(adata_sub.obs_names) full_names = list(adata_full.obs_names) parent_idx = np.array([i for i, n in enumerate(full_names) if n in sub_names]) # Extract sub × sub block of the transition matrix T_sub = T_full[parent_idx, :][:, parent_idx] # (n_sub, n_sub) # Row-normalize row_sums = np.array(T_sub.sum(axis=1)).ravel() row_sums[row_sums == 0] = 1.0 T_norm = sp.diags(1.0 / row_sums) @ T_sub expected = T_norm @ coords_sub # (n_sub, 2) V_sub = expected - coords_sub vx, vy = V_sub[:, 0], V_sub[:, 1] adata_sub.obsm["velocity_sccs"] = V_sub if verbose: print(f"[scCS] Graph-based velocity projected. Shape: {V_sub.shape}") return vx, vy except Exception as e: warnings.warn( f"Graph-based projection from full adata failed ({e}). " "Falling back to subset-only projection.", RuntimeWarning, stacklevel=2, ) # ── Strategy 3: last resort — use whatever graph is in adata_sub ───────── if verbose: warnings.warn( "No full adata provided and no compatible velocity_graph found. " "Using subset-only graph (may have dimension issues). " "Pass adata_full=adata to project_velocity() for best results.", RuntimeWarning, stacklevel=2, ) vx, vy = _graph_velocity_projection(adata_sub, coords_sub, verbose=verbose) adata_sub.obsm["velocity_sccs"] = np.stack([vx, vy], axis=1) return vx, vy
[docs] def run_velocity_pipeline( adata, mode: str = "dynamical", n_top_genes: int = 2000, n_pcs: int = 30, n_neighbors: int = 30, min_shared_counts: int = 20, verbose: bool = True, ) -> None: """Run the full scVelo RNA velocity pipeline. Requires spliced and unspliced count layers. Parameters ---------- adata : AnnData Must contain layers 'spliced' and 'unspliced'. mode : {'dynamical', 'stochastic', 'steady_state'} n_top_genes : int n_pcs : int n_neighbors : int min_shared_counts : int verbose : bool """ if not _SCVELO_AVAILABLE: raise ImportError("scvelo is required. pip install scvelo") missing = [l for l in ["spliced", "unspliced"] if l not in adata.layers] if missing: raise ValueError( f"Missing required layers: {missing}. " "These are generated by velocyto, STARsolo, or alevin-fry." ) if verbose: print(f"[scCS] Running scVelo pipeline (mode='{mode}')...") scv.pp.filter_and_normalize( adata, min_shared_counts=min_shared_counts, n_top_genes=n_top_genes, log=True, ) if "X_pca" not in adata.obsm and _SCANPY_AVAILABLE: sc.tl.pca(adata, n_comps=n_pcs) if "neighbors" not in adata.uns and _SCANPY_AVAILABLE: sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs) scv.pp.moments(adata, n_pcs=n_pcs, n_neighbors=n_neighbors) if mode == "dynamical": try: scv.tl.recover_dynamics(adata, n_jobs=-1) scv.tl.velocity(adata, mode="dynamical") except Exception as e: warnings.warn( f"Dynamical model failed ({e}). Falling back to stochastic.", RuntimeWarning, stacklevel=2, ) scv.tl.velocity(adata, mode="stochastic") else: scv.tl.velocity(adata, mode=mode) scv.tl.velocity_graph(adata) try: scv.tl.velocity_pseudotime(adata) except Exception: pass if verbose: print("[scCS] Velocity pipeline complete.")
# --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _resolve_metric( adata, metric: Union[str, np.ndarray], invert: bool, ) -> np.ndarray: """Resolve differentiation metric to a per-cell float array.""" n_cells = adata.n_obs if isinstance(metric, np.ndarray): scores = np.asarray(metric, dtype=float).ravel() if len(scores) != n_cells: raise ValueError( f"Custom metric array has length {len(scores)}, " f"expected {n_cells}." ) elif metric == "pseudotime": if "velocity_pseudotime" not in adata.obs: if _SCVELO_AVAILABLE and "velocity_graph" in adata.uns: scv.tl.velocity_pseudotime(adata) else: warnings.warn( "velocity_pseudotime not found and cannot be computed. " "Falling back to uniform scores (random ordering).", RuntimeWarning, stacklevel=3, ) scores = np.random.default_rng(0).uniform(0, 1, n_cells) if invert: scores = 1.0 - scores return _fill_nan(scores) scores = np.array(adata.obs["velocity_pseudotime"], dtype=float) # NOTE: This pseudotime was computed on the full adata. After subsetting, # the caller should invoke compute_local_pseudotime() to get a # subset-local pseudotime with better arm coverage. elif metric == "cytotrace": # CytoTRACE2: look for common column names candidates = ["cytotrace2_score", "CytoTRACE2_Score", "cytotrace_score", "CytoTRACE2", "cytotrace2"] found = None for c in candidates: if c in adata.obs: found = c break if found is None: raise ValueError( "CytoTRACE2 score not found in adata.obs. " f"Expected one of: {candidates}. " "Run CytoTRACE2 first or pass the column name as metric." ) scores = np.array(adata.obs[found], dtype=float) # CytoTRACE2: high score = stem-like = LESS differentiated # So we invert by default unless user explicitly set invert=False # We flip the invert flag here since CytoTRACE2 is naturally inverted invert = not invert else: # Treat as column name in adata.obs if metric not in adata.obs: raise ValueError( f"Column '{metric}' not found in adata.obs. " f"Available columns: {list(adata.obs.columns)}" ) scores = np.array(adata.obs[metric], dtype=float) scores = _fill_nan(scores) if invert: scores = scores.max() - scores return scores def _fill_nan(scores: np.ndarray) -> np.ndarray: """Replace NaN or inf values with a sensible finite value. NaN entries are filled with the finite median. Positive-/negative-inf entries are clipped to the finite max/min so downstream consumers (star embedding arm projection, etc.) never produce NaN positions when fed pseudotime that contains inf from a degenerate diffmap. """ scores = np.asarray(scores, dtype=float).copy() finite = scores[np.isfinite(scores)] if finite.size == 0: # No information at all — return zeros. return np.zeros_like(scores) fin_min, fin_max, fin_med = finite.min(), finite.max(), float(np.median(finite)) pos_inf = np.isposinf(scores) neg_inf = np.isneginf(scores) nan_mask = np.isnan(scores) if pos_inf.any(): scores[pos_inf] = fin_max if neg_inf.any(): scores[neg_inf] = fin_min if nan_mask.any(): scores[nan_mask] = fin_med return scores def _graph_velocity_projection( adata, coords: np.ndarray, verbose: bool = True, ) -> Tuple[np.ndarray, np.ndarray]: """Fallback: project velocity using the velocity graph transition matrix. For each cell i, the velocity vector is the weighted average displacement to its neighbors, weighted by the transition probability T[i, j]: v_i = sum_j T[i,j] * (x_j - x_i) Parameters ---------- adata : AnnData coords : np.ndarray, shape (n_cells, 2) verbose : bool Returns ------- vx, vy : np.ndarray, shape (n_cells,) """ import scipy.sparse as sp if verbose: print("[scCS] Using graph-based velocity projection...") # Try velocity_graph first, then connectivities as fallback T = None for key in ["velocity_graph", "velocity_graph_neg"]: if key in adata.uns: T_raw = adata.uns[key] if sp.issparse(T_raw): T = T_raw else: T = sp.csr_matrix(T_raw) break if T is None: # Last resort: use kNN connectivities if "connectivities" in adata.obsp: T = adata.obsp["connectivities"] if verbose: warnings.warn( "velocity_graph not found. Using kNN connectivities as proxy.", RuntimeWarning, stacklevel=2, ) else: warnings.warn( "No velocity graph or connectivity matrix found. " "Returning zero velocity vectors.", RuntimeWarning, stacklevel=2, ) return np.zeros(adata.n_obs), np.zeros(adata.n_obs) # Row-normalize transition matrix T = T.astype(float) row_sums = np.array(T.sum(axis=1)).ravel() row_sums[row_sums == 0] = 1.0 T_norm = sp.diags(1.0 / row_sums) @ T # Expected position under transition expected_coords = T_norm @ coords # (n_cells, 2) V = expected_coords - coords # displacement = velocity return V[:, 0], V[:, 1] # --------------------------------------------------------------------------- # Subset-local pseudotime recomputation # ---------------------------------------------------------------------------
[docs] def compute_local_pseudotime( adata_sub, adata_full, scale_01: bool = True, verbose: bool = True, ) -> np.ndarray: """Recompute velocity pseudotime on the subset's induced subgraph. When ``build_star_embedding`` uses ``ordering_metric='pseudotime'``, the pseudotime is resolved on the full adata before subsetting. This means the pseudotime range within the bifurcation+fate subset is compressed and non-uniform: cells that span the full differentiation axis in the subset may all cluster near 0 or 1 on the arm, leaving large empty stretches. This function extracts the velocity_graph submatrix for the subset cells, recomputes pseudotime locally, and optionally scales it to [0, 1]. The result is stored in ``adata_sub.obs['sccs_pseudotime']`` and returned as an array. Call this after ``build_embedding()`` and before (or instead of) using the full-adata pseudotime for arm ordering. To rebuild the embedding with the corrected pseudotime, pass the returned array as a custom metric:: scorer.build_embedding(ordering_metric='pseudotime') pt_sub = compute_local_pseudotime(scorer.adata_sub, adata) scorer.build_embedding(ordering_metric=pt_sub_full) # where pt_sub_full is the subset scores mapped back to full adata indices Alternatively, use the convenience method ``SingleScorer.refit_pseudotime()``. Parameters ---------- adata_sub : AnnData Subset returned by ``build_star_embedding()``. Must have ``uns['sccs']['parent_indices']`` set (done automatically). adata_full : AnnData Full dataset with intact ``uns['velocity_graph']``. scale_01 : bool If True (default), min-max scale the recomputed pseudotime to [0, 1] within the subset. This ensures cells span the full arm length regardless of where the subset sits in the global pseudotime range. If False, the raw pseudotime values are returned (useful when you want to compare absolute pseudotime across conditions). verbose : bool Returns ------- pt_sub : np.ndarray, shape (n_sub_cells,) Subset-local pseudotime, stored in ``adata_sub.obs['sccs_pseudotime']``. """ if not _SCVELO_AVAILABLE: raise ImportError( "scvelo is required for pseudotime recomputation. pip install scvelo" ) if "velocity_graph" not in adata_full.uns: raise ValueError( "velocity_graph not found in adata_full.uns. " "Run scvelo.tl.velocity_graph() first." ) import scipy.sparse as sp parent_idx = adata_sub.uns.get("sccs", {}).get("parent_indices", None) if parent_idx is None: # Fall back to obs_names matching sub_names = set(adata_sub.obs_names) full_names = list(adata_full.obs_names) parent_idx = np.array([i for i, n in enumerate(full_names) if n in sub_names]) if verbose: print( f"[scCS] Recomputing pseudotime on subset " f"({len(parent_idx)} / {adata_full.n_obs} cells)..." ) # Extract the sub × sub block of the velocity graph T_full = adata_full.uns["velocity_graph"] if not sp.issparse(T_full): T_full = sp.csr_matrix(T_full) T_sub = T_full[parent_idx, :][:, parent_idx] # (n_sub, n_sub) # Inject the subgraph into a temporary copy of adata_sub for scVelo adata_tmp = adata_sub.copy() adata_tmp.uns["velocity_graph"] = T_sub # scVelo's velocity_pseudotime uses the graph to compute a diffusion-based # ordering. We need neighbors connectivities too; use the subset block. if "connectivities" in adata_full.obsp: C_full = adata_full.obsp["connectivities"] if not sp.issparse(C_full): C_full = sp.csr_matrix(C_full) C_sub = C_full[parent_idx, :][:, parent_idx] adata_tmp.obsp["connectivities"] = C_sub adata_tmp.obsp["distances"] = C_sub # placeholder; scVelo only needs connectivities try: scv.tl.velocity_pseudotime(adata_tmp) pt_sub = np.array(adata_tmp.obs["velocity_pseudotime"], dtype=float) except Exception as e: warnings.warn( f"scvelo.tl.velocity_pseudotime on subset failed ({e}). " "Falling back to diffusion pseudotime via scanpy.", RuntimeWarning, stacklevel=2, ) pt_sub = _fallback_dpt(adata_tmp, verbose=verbose) pt_sub = _fill_nan(pt_sub) if scale_01: pt_min, pt_max = pt_sub.min(), pt_sub.max() if pt_max > pt_min: pt_sub = (pt_sub - pt_min) / (pt_max - pt_min) else: pt_sub = np.zeros_like(pt_sub) if verbose: print("[scCS] Subset pseudotime scaled to [0, 1].") adata_sub.obs["sccs_pseudotime"] = pt_sub if verbose: print( f"[scCS] Subset pseudotime stored in " f"adata_sub.obs['sccs_pseudotime']. " f"Range: [{pt_sub.min():.3f}, {pt_sub.max():.3f}]" ) return pt_sub
[docs] def scale_metric_01(scores: np.ndarray) -> np.ndarray: """Min-max scale a per-cell metric to [0, 1]. Useful for normalizing any differentiation metric (pseudotime, CytoTRACE2, pathway score, etc.) before passing it to ``build_star_embedding`` so that cells span the full arm length uniformly. Parameters ---------- scores : np.ndarray, shape (n_cells,) Per-cell metric values. NaN values are preserved. Returns ------- scaled : np.ndarray, shape (n_cells,) Values in [0, 1]. Returns zeros if all values are identical. """ scores = np.asarray(scores, dtype=float) s_min = np.nanmin(scores) s_max = np.nanmax(scores) if s_max <= s_min: return np.zeros_like(scores) return (scores - s_min) / (s_max - s_min)
def _fallback_dpt(adata_tmp, verbose: bool = True) -> np.ndarray: """Fallback: diffusion pseudotime via scanpy when scVelo fails. Runs ``sc.tl.diffmap`` before ``sc.tl.dpt`` (scanpy emits a warning otherwise and falls back to default-parameter diffmap, which can yield ``inf`` pseudotime for cells in disconnected components of a sparse star-subset graph). Any remaining non-finite values are clipped to the finite range, so the returned array is always finite and the star embedding never produces NaN arm positions. """ if not _SCANPY_AVAILABLE: warnings.warn( "scanpy not available for DPT fallback. Returning radial distance.", RuntimeWarning, stacklevel=2, ) coords = np.array(adata_tmp.obsm["X_sccs"]) return np.linalg.norm(coords, axis=1) try: import scanpy as sc if "connectivities" not in adata_tmp.obsp: sc.pp.neighbors(adata_tmp, n_neighbors=15, use_rep="X_sccs") # Root = cell with smallest scCS radius coords = np.array(adata_tmp.obsm["X_sccs"]) radii = np.linalg.norm(coords, axis=1) root_idx = int(np.argmin(radii)) adata_tmp.uns["iroot"] = root_idx # Prerequisite for DPT: diffusion map components. # Running this explicitly avoids scanpy's default-fallback path, # which on disconnected star-subsets can produce inf pseudotime. try: sc.tl.diffmap(adata_tmp, n_comps=15) except Exception: # Some adatas lack neighbors with enough components; rebuild. sc.pp.neighbors(adata_tmp, n_neighbors=15, use_rep="X_sccs") sc.tl.diffmap(adata_tmp, n_comps=15) sc.tl.dpt(adata_tmp) pt = np.array(adata_tmp.obs["dpt_pseudotime"], dtype=float) # Repair non-finite entries (inf from disconnected diffmap components). if not np.isfinite(pt).all(): finite = pt[np.isfinite(pt)] if finite.size > 0: pt = np.where(np.isfinite(pt), pt, finite.max()) else: # Total failure: fall back to radial distance pt = np.linalg.norm(coords, axis=1) warnings.warn( "DPT produced no finite pseudotime; using radial distance.", RuntimeWarning, stacklevel=2, ) if verbose: print("[scCS] Used scanpy DPT (with diffmap) as pseudotime fallback.") return pt except Exception as e2: warnings.warn( f"DPT fallback also failed ({e2}). Returning radial distance as pseudotime.", RuntimeWarning, stacklevel=2, ) coords = np.array(adata_tmp.obsm["X_sccs"]) return np.linalg.norm(coords, axis=1)