From 5b4287ddb7c93e60aa85c9639b32918d6262ff96 Mon Sep 17 00:00:00 2001 From: Francesco Strino Date: Tue, 10 Feb 2026 13:39:02 +0200 Subject: [PATCH 1/2] first version --- locat/__init__.py | 5 + locat/locat_condensed.py | 1398 +++++++++++++++++++++++++++ locat/plotting_and_other_methods.py | 383 ++++++++ locat/rgmm.py | 207 ++++ locat/wgmm.py | 93 ++ locat/wgmms.py | 141 +++ pyproject.toml | 54 ++ 7 files changed, 2281 insertions(+) create mode 100644 locat/__init__.py create mode 100755 locat/locat_condensed.py create mode 100755 locat/plotting_and_other_methods.py create mode 100755 locat/rgmm.py create mode 100755 locat/wgmm.py create mode 100755 locat/wgmms.py create mode 100644 pyproject.toml diff --git a/locat/__init__.py b/locat/__init__.py new file mode 100644 index 0000000..1629ac4 --- /dev/null +++ b/locat/__init__.py @@ -0,0 +1,5 @@ +# Get the version from _version.py (added when building using scm) +try: + from .version import __version__ # noqa +except ModuleNotFoundError as e: + __version__ = '0.0.0-dev' \ No newline at end of file diff --git a/locat/locat_condensed.py b/locat/locat_condensed.py new file mode 100755 index 0000000..0a7e8f0 --- /dev/null +++ b/locat/locat_condensed.py @@ -0,0 +1,1398 @@ +import math + +import numba +import numpy as np +from loguru import logger +from scanpy import AnnData +from scipy.interpolate import PchipInterpolator +from scipy.special import logsumexp +from scipy.stats import binom, betabinom, chi2 +from sklearn.metrics.pairwise import euclidean_distances as edist +from tqdm import tqdm +from sklearn.neighbors import NearestNeighbors +from scipy.stats import norm +from locat.wgmm import WGMM +from locat.wgmms import wgmm +from locat.rgmm import softbootstrap_gmm + +class LOCATNullDistribution: + mean = None + std = None + + def __init__(self, mean_func, std_func): + self.mean = mean_func + self.std = std_func + + def to_zscore(self, raw_score, p): + return (raw_score - self.mean(p)) / self.std(p) + + @classmethod + def from_estimates(cls, p, means, stds): + return cls( + mean_func=PchipInterpolator(p, means), + std_func=PchipInterpolator(p, stds), + ) + + +class LOCAT: + _cell_dist = None + _min_dist = None + _knn = None + _background_gmm: WGMM = None + _background_pdf: np.ndarray = None + _background_logpdf: np.ndarray = None + _null_distribution: LOCATNullDistribution = None + _X = None + _n_components_waypoints = None + _disable_progress_info = True + + def __init__( + self, + adata: AnnData, + cell_embedding: np.ndarray, + k: int, + n_bootstrap_inits: int, + show_progress=False, + wgmm_dtype: str = "same", # "same" | "float32" | "float64" + knn=None, # <-- NEW: precomputed adjacency/connectivities + knn_k: int | None = None, # <-- NEW: if computing, how many neighbors (defaults to k) + knn_mode: str = "binary", # <-- NEW: "binary" or "connectivity" + ): + self._disable_progress_info = not show_progress + self._adata = adata + + emb = np.asarray(cell_embedding) + emb = (emb - emb.mean(0)) / (emb.std(0) + emb.dtype.type(1e-6)) + self._embedding = emb + self._dtype = self._embedding.dtype + self._wgmm_dtype = wgmm_dtype + + self._k = k + self.n_cells, self.n_genes = self._adata.shape + self.n_dims = self._embedding.shape[1] + self.n_bootstrap_inits = n_bootstrap_inits + self._knn = None + self._knn_k = int(knn_k) if knn_k is not None else int(k) + self._knn_mode = str(knn_mode) + + self._reg_covar = None + if knn is not None: + self.set_knn(knn) + # ------------------------------------------------------------------ + # Basic geometry / 
regularization + # ------------------------------------------------------------------ + @property + def cell_dist(self) -> np.ndarray: + if self._cell_dist is None: + if not self._disable_progress_info: + logger.info("recomputing cell-cell distance") + self._cell_dist = edist(self._embedding) + return self._cell_dist + + @property + def min_dist(self) -> float: + if self._min_dist is None: + if not self._disable_progress_info: + logger.info("recomputing min cell-cell distance") + self._min_dist = np.mean( + np.nanmin( + np.where(self.cell_dist > 0, self.cell_dist, np.nan), + axis=1, + ) + ) + return self._min_dist + + def reg_covar(self, sample_size=None): + if self._reg_covar is not None: + return self._dtype.type(self._reg_covar) + + base = (self.min_dist ** (2 / self.n_dims)) / 6 + if sample_size is None: + rc = base + else: + c = 1 - 2 / self.n_dims + adj = np.sqrt((self.n_cells + c) / (sample_size + c)) + rc = base * adj + + rc = max(rc, 1e-4) + return self._dtype.type(rc) + + # ------------------------------------------------------------------ + # Data access / options + # ------------------------------------------------------------------ + @property + def X(self): + if self._X is None: + X = self._adata.X + if not isinstance(X, np.ndarray): + X = X.toarray() + self._X = X.astype(self._dtype, copy=False) + return self._X + + def show_progress(self, show_progress=True): + self._disable_progress_info = not show_progress + + # ------------------------------------------------------------------ + # Background GMM and LTST null + # ------------------------------------------------------------------ + def background_n_components_init(self, weights_transform=None, min_points=10, n_reps=30): + if not self._disable_progress_info: + logger.info("Estimating number of GMM components") + + self._n_components_waypoints = np.zeros(shape=(3, 2)) + self._n_components_waypoints[0, 0] = min_points + self._n_components_waypoints[1, 0] = int(np.sqrt(self.n_cells)) + self._n_components_waypoints[2, 0] = self.n_cells + + self._n_components_waypoints[0, 1] = 1 + + bic_component_cost = self.n_dims * (self.n_dims + 3) / 2 + + if weights_transform is not None: + Xdense = self.X.copy() + for i in range(self.n_genes): + Xdense[:, i] = weights_transform(Xdense[:, i]) + weights = np.asarray(Xdense.sum(axis=1), dtype=self._dtype) + else: + weights = np.ones((self.n_cells,), dtype=self._dtype) + + weights = np.clip(weights, self._dtype.type(1e-6), np.inf).astype( + self._dtype, copy=False + ) + s = float(weights.sum()) + weights = (weights / s) if s > 0 else np.full_like(weights, 1.0 / len(weights)) + weights[~np.isfinite(weights)] = self._dtype.type(1.0 / len(weights)) + + # Full-data BIC search for big-N waypoint + min_search = 1 + max_search = 10 + keep_searching = True + bic = [] + while keep_searching: + for i in tqdm( + range(min_search, max_search), + desc=f"estimating BIC for {self._n_components_waypoints[2, 0]:.0f} cells", + position=0, + leave=True, + disable=self._disable_progress_info, + ): + p = self.fit_wgmm(n_comp=i, weights=weights).pdf(self._embedding) + bic.append( + bic_component_cost * i * np.log(self.n_cells) + - np.sum(2 * np.log(p[p > 0])) + ) + self._n_components_waypoints[2, 1] = np.argmin(bic) + 1 + if self._n_components_waypoints[2, 1] >= len(bic) - 2: + min_search = max_search + max_search = max_search + 5 + else: + keep_searching = False + + # sqrt(n_cells) waypoint + n = np.max([min_points + 1, int(np.ceil(np.sqrt(self.n_cells)))]) + bic = 
np.zeros(shape=(int(self._n_components_waypoints[2, 1]), n_reps)) + + o = np.arange(self.n_cells) + wgs = np.zeros(shape=(self.n_cells, n_reps)) + for i_rep in range(n_reps): + np.random.shuffle(o) + wgs[o < n, i_rep] = weights[o < n] + wgs = wgs / np.sum(wgs, axis=0) + + for i_rep in tqdm( + range(n_reps), + desc=f"estimating BIC for {self._n_components_waypoints[1, 0]:.0f} cells", + position=0, + leave=True, + disable=self._disable_progress_info, + ): + for i in range(1, bic.shape[0] + 1): + p = self.fit_wgmm(n_comp=i, weights=wgs[:, i_rep]).pdf(self._embedding) + bic[i - 1, i_rep] = ( + bic_component_cost * i * np.log(n) - np.sum(2 * np.log(p[p > 0])) + ) + self._n_components_waypoints[1, 1] = np.argmin(np.median(bic, axis=1)) + 1 + + def auto_bkg_components(self, n_points, weights_transform=None): + if self._n_components_waypoints is None: + self.background_n_components_init(weights_transform=weights_transform) + + n_components = np.interp( + np.log(n_points), + xp=np.log(self._n_components_waypoints[:, 0]), + fp=np.log(self._n_components_waypoints[:, 1]), + ) + return int(np.round(np.exp(n_components))) + + def _auto_n_effective_weights(self, weights=None, min_dist_cutoff=None): + if min_dist_cutoff is None: + min_dist_cutoff = self.min_dist * 3 + + rweights = np.zeros(shape=(self.n_cells,)) + if weights is None: + wg0 = rweights == 0 + else: + wg0 = weights > 0 + + if np.sum(wg0) > 5: + # only compute the components to describe the cells that are close enough + cell_include = ( + np.min(self.cell_dist[wg0, :][:, wg0] + np.eye(np.sum(wg0)) * 10, axis=1) + < min_dist_cutoff + ) + rweights[wg0] = cell_include.astype(float) + return rweights + return None + + def auto_n_components(self, coords, weights=None, indices=None, min_points_fraction=0.95): + if weights is None: + weights = np.ones(shape=(coords.shape[0],)) + + if indices is None: + if coords.shape[0] != self.n_cells: + raise ValueError("A selection was made but the indices were not passed") + indices = np.full_like(weights, fill_value=True) + + logpdf = self._background_logpdf[indices, :] * weights[:, None] + id_component = np.argmax(logpdf, axis=1) + + component_counts = np.bincount( + id_component[weights > 0], minlength=self._background_gmm.n_comp + ) + n_counts = np.sum(component_counts) + component_counts[component_counts < 5] = 0 + if np.sum(component_counts) < 10: + return 1 + else: + component_counts = component_counts / np.sum(component_counts) + estimated_n_components = ( + np.min( + np.flatnonzero( + np.cumsum(np.sort(component_counts)[::-1]) > min_points_fraction + ) + ) + + 1 + ) + estimated_n_components = int(np.floor(np.sqrt(estimated_n_components + 1)) + 1) + return int( + np.max( + [ + 1, + np.min( + [ + self._background_gmm.n_comp, + n_counts / 5, + estimated_n_components, + ] + ), + ] + ) + ) + def set_knn(self, knn): + """ + Store a precomputed KNN graph. + + Accepts: + - scipy sparse (csr/csc/coo) adjacency or connectivities + - dense numpy array adjacency/connectivities + + Expected shape: (n_cells, n_cells) + """ + if sp.issparse(knn): + K = knn.tocsr() + else: + K = np.asarray(knn) + + if K.shape != (self.n_cells, self.n_cells): + raise ValueError(f"knn must have shape {(self.n_cells, self.n_cells)}, got {K.shape}") + + # zero diagonal to avoid self-neighbor artifacts + if sp.issparse(K): + K = K.tolil() + K.setdiag(0.0) + K = K.tocsr() + K.eliminate_zeros() + else: + np.fill_diagonal(K, 0.0) + + self._knn = K + + + def knn(self): + """ + Return a KNN adjacency/connectivity matrix. 
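+        Shape is (n_cells, n_cells). With knn_mode='binary' edges are 0/1; with
+        knn_mode='connectivity' they are weighted by inverse distance in the embedding.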
+ + If a KNN was provided at init (or via set_knn), returns it. + Otherwise computes one from the embedding. + """ + if self._knn is not None: + return self._knn + + k = int(self._knn_k) + if k <= 0: + raise ValueError("knn_k must be >= 1") + + # Build kNN from embedding (exclude self by taking k+1 then dropping self) + nbrs = NearestNeighbors(n_neighbors=min(k + 1, self.n_cells), metric="euclidean") + nbrs.fit(self._embedding) + dists, inds = nbrs.kneighbors(self._embedding, return_distance=True) + + # Drop self neighbor (usually first column) + inds = inds[:, 1: k + 1] + + rows = np.repeat(np.arange(self.n_cells), inds.shape[1]) + cols = inds.reshape(-1) + + if self._knn_mode == "binary": + data = np.ones_like(cols, dtype=self._dtype) + elif self._knn_mode == "connectivity": + # simple distance-based weights; you can swap for something else + dd = dists[:, 1: k + 1].reshape(-1) + data = (1.0 / (dd + 1e-6)).astype(self._dtype, copy=False) + else: + raise ValueError("knn_mode must be 'binary' or 'connectivity'") + + K = sp.csr_matrix((data, (rows, cols)), shape=(self.n_cells, self.n_cells)) + K.eliminate_zeros() + self._knn = K + return self._knn + + + def background_pdf( + self, + n_comp=None, + reps=10, + total_counts_weight=True, + weights_transform=None, + force_refresh=False, + ): + if (self._background_pdf is None) or force_refresh: + if not self._disable_progress_info: + logger.info("fitting background PDF") + + if n_comp is None: + if self._n_components_waypoints is None: + self.background_n_components_init(weights_transform=weights_transform) + n_comp = np.interp( + np.log(self.n_cells), + xp=np.log(self._n_components_waypoints[:, 0]), + fp=np.log(self._n_components_waypoints[:, 1]), + ) + n_comp = int(np.round(np.exp(n_comp))) + if not self._disable_progress_info: + logger.info(f"Using {n_comp} components") + + if weights_transform is not None: + Xdense = self.X.copy() + for i in range(self.n_genes): + Xdense[:, i] = weights_transform(Xdense[:, i]) + weights = np.asarray(Xdense.sum(axis=1), dtype=self._dtype) + else: + weights = np.ones((self.n_cells,), dtype=self._dtype) + + weights = np.clip(weights, self._dtype.type(1e-6), np.inf).astype( + self._dtype, copy=False + ) + s = float(weights.sum()) + weights = (weights / s) if s > 0 else np.full_like(weights, 1.0 / len(weights)) + weights[~np.isfinite(weights)] = self._dtype.type(1.0 / len(weights)) + + self._background_pdf = np.zeros(shape=(self.n_cells,)) + self._background_logpdf = np.zeros(shape=(self.n_cells, n_comp)) + + background_gmm = None + for _ in tqdm( + range(reps), + desc="fitting background", + position=0, + leave=True, + disable=self._disable_progress_info, + ): + background_gmm = self.fit_wgmm(n_comp, weights=weights) + self._background_pdf += background_gmm.pdf(self._embedding) + self._background_logpdf += background_gmm.loglikelihood_by_component( + self._embedding, np.ones(shape=self.n_cells) + ) + self._background_gmm = background_gmm + self._background_pdf /= reps + self._background_logpdf /= reps + + self.estimate_null_parameters() + return self._background_pdf + + def signal_gmm(self, weights, n_comp=None): + if n_comp is None: + comp_weights = self._auto_n_effective_weights(weights) + if comp_weights is None: + n_comp = 1 + else: + loc_indices = comp_weights > 0 + n_comp = self.auto_n_components( + self._embedding[loc_indices, :], + weights[loc_indices], + indices=loc_indices, + ) + + return self.fit_wgmm(n_comp, weights=weights) + + def fit_wgmm(self, n_comp, weights=None) -> WGMM: + if weights is None: 
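+            # No weights supplied: use a uniform weight per cell, i.e. a plain
+            # unweighted GMM fit over the embedding.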
+ weights = np.ones(shape=(self.n_cells,)) + + pis, mus, sigmas, _ = wgmm( + self._embedding, + raw_weights=weights, + n_components=n_comp, + n_inits=1, + reg_covar=self.reg_covar(), + ) + pis = np.array(pis) + mus = np.array(mus) + sigmas = np.array(sigmas) + return WGMM(pis, mus, sigmas) + + def estimate_null_parameters(self, fractions=None, n_reps=50): + """ + Estimate LTST null mean/std as a function of expression fraction p + using random pseudo-genes: + - pick random expressing cells at frequency p + - fit signal GMM with the *same* pipeline as real genes + - compute LTST exactly as in gmm_scan_new + """ + if fractions is None: + fractions = 10 ** np.linspace(np.log10(10 / self.n_cells), 0, 7) + + f0 = self.background_pdf() + scores = [] + + for frac in tqdm( + fractions, + desc="null distribution parameters (perm. pseudo-genes)", + position=0, + leave=True, + disable=self._disable_progress_info, + ): + n_pos = max(5, int(round(frac * self.n_cells))) + if n_pos >= self.n_cells: + n_pos = self.n_cells - 1 + + ltst_vals = [] + + for _ in range(n_reps): + mask = np.zeros(self.n_cells, dtype=bool) + mask[np.random.choice(self.n_cells, n_pos, replace=False)] = True + gene_prior = mask.astype(self._dtype) + + comp_gene_prior = self._auto_n_effective_weights(gene_prior) + if comp_gene_prior is None: + n_comp = 1 + else: + loc_indices = comp_gene_prior > 0 + n_comp = self.auto_n_components( + self._embedding[loc_indices, :], + gene_prior[loc_indices], + loc_indices, + ) + gmm1 = self.signal_gmm(weights=gene_prior, n_comp=n_comp) + + i1 = gene_prior > 0 + loc_f1 = gmm1.pdf(self._embedding[i1, :]) + p1 = float(np.mean(i1)) + + ltst_score = ltst_score_func(f0[i1], loc_f1, p1) + + w_expr = gene_prior[i1] + w_expr = w_expr / (w_expr.sum() if w_expr.sum() > 0 else 1.0) + ltst_vals.append(float(np.dot(w_expr, ltst_score))) + + ltst_vals = np.asarray(ltst_vals, dtype=float) + scores.append( + [ + frac, + float(np.mean(ltst_vals)), + float(np.std(ltst_vals) + 1e-9), + ] + ) + + scores = np.asarray(scores) + self._null_distribution = LOCATNullDistribution.from_estimates( + p=scores[:, 0], + means=scores[:, 1], + stds=scores[:, 2], + ) + + # ------------------------------------------------------------------ + # Depletion-style localization scan + # ------------------------------------------------------------------ + def localization_pval_dep_scan( + self, + gmm1, + gene_prior, + *, + c_values=None, + soft_bound=None, # default computed from n: max((n-1)/n, 0.99) + min_p0_abs=0.10, + min_expected=30, + min_abs_deficit=0.02, + n_trials_cap=500, + weight_mode="amount", + p_floor=1e-12, + n_eff_scale=1.0, + rho_bb=0.2, # >0 enables Beta–Binomial tail + eps_rel=0.01, + debug=False, + debug_store_masks=False, + debug_max_cells=5000, + ): + if self._background_gmm is None: + _ = self.background_pdf() + + X = self._embedding + n = int(X.shape[0]) + if soft_bound is None: + soft_bound = max((n - 1) / max(n, 1), 0.99) + + gp = np.asarray(gene_prior, float) + if weight_mode == "binary": + w_obs = (gp > 0).astype(float) + elif weight_mode == "amount": + w_obs = np.clip(gp, 0.0, np.inf) + else: + raise ValueError("weight_mode must be 'amount' or 'binary'") + + # Kish n_eff with tempering + cap + sw = float(np.sum(w_obs)) + n_eff = 0.0 if sw <= 0 else (sw * sw) / max(float(np.sum(w_obs * w_obs)), 1e-12) + n_eff *= float(n_eff_scale) + if n_trials_cap is not None: + n_eff = min(n_eff, float(n_trials_cap)) + n_eff = max(1.0, n_eff) + n_trials = int(round(n_eff)) + + f0_x = np.clip(self._background_gmm.pdf(X), 1e-300, 
np.inf) + f1_x = np.clip(gmm1.pdf(X), 1e-300, np.inf) + w0 = f0_x / float(np.sum(f0_x)) + + w_obs_alpha = w_obs / (sw if sw > 0 else 1.0) + w0_alpha = w0 + + if c_values is None: + c_values = np.concatenate([[1.0], np.geomspace(1.05, 3.0, 12)]) + + best_logp = None + best = {"c": None, "k_obs": None, "p0": None, "obs_prop": None} + scanned = 0 + tested = 0 + + per_c = [] if debug else None + + for c in c_values: + c = float(c) + in_R_mask = f0_x > c * f1_x + + p0_abs = float(np.sum(w0_alpha * in_R_mask)) + reason = None + if p0_abs < min_p0_abs: + reason = "fail:min_p0_abs" + if n_eff * p0_abs < min_expected: + reason = "fail:min_expected" + if p0_abs > soft_bound: + reason = "fail:soft_bound" + + obs_prop = float(np.sum(w_obs_alpha * in_R_mask)) + if reason is None: + if (p0_abs - obs_prop) < min_abs_deficit: + reason = "fail:min_abs_deficit" + elif obs_prop > (p0_abs / c) * (1.0 - float(eps_rel)): + reason = "fail:c_bound" #this checks whether observed f1 density in the region is at an equal or lower proportion than the observed f0 density in the region * c (where c is the contrast). If the region is larger than expectation, reject the gene. + + if debug: + rec = { + "c": c, + "p0_abs": p0_abs, + "obs_prop": obs_prop, + "n_eff": float(n_eff), + "n_eff_expected": float(n_eff * p0_abs), + "reason": reason, + } + if debug_store_masks: + ncap = min(int(debug_max_cells), n) + rec["in_R_idx"] = np.flatnonzero(in_R_mask[:ncap]).astype(int) + if weight_mode == "binary": + expr_mask = gp > 0 + rec["expr_in_R_count"] = int(np.sum(expr_mask & in_R_mask)) + per_c.append(rec) + + if reason is not None: + continue + + scanned += 1 + tested += 1 + + k_eff = int(np.rint(obs_prop * n_trials)) + p0_clip = np.clip(p0_abs, 1e-12, 1 - 1e-12) + + if rho_bb and rho_bb > 0.0: + ab_sum = max(1.0 / float(rho_bb) - 1.0, 2.0) + alpha = float(p0_clip * ab_sum) + beta = float((1.0 - p0_clip) * ab_sum) + p_raw = float(betabinom.cdf(k_eff, n_trials, alpha, beta)) + logp_raw = np.log(max(p_raw, np.finfo(float).tiny)) + else: + lFkm1 = binom.logcdf(k_eff - 1, n_trials, p0_clip) + lpk = binom.logpmf(k_eff, n_trials, p0_clip) + np.log(0.5) + logp_raw = logsumexp([lFkm1, lpk]) + + if (best_logp is None) or (logp_raw < best_logp): + best_logp = float(logp_raw) + best.update( + {"c": c, "k_obs": k_eff, "p0": p0_abs, "obs_prop": obs_prop} + ) + + if tested == 0: + out = { + "p_value": 1.0, + "raw_min_p": 1.0, + "log_p_single": 0.0, + "log_p_sidak": 0.0, + "neglog10_p_single": 0.0, + "neglog10_p_sidak": 0.0, + "best_c": None, + "k_obs_eff": None, + "p0_abs": None, + "obs_prop": None, + "scanned": int(scanned), + "tested": int(tested), + "sidak_penalty": 1, + "n": n, + "n_eff": float(n_eff), + "guards": { + "min_p0_abs": float(min_p0_abs), + "min_expected": float(min_expected), + "min_abs_deficit": float(min_abs_deficit), + "soft_bound": float(soft_bound), + "eps_rel": float(eps_rel), + }, + "n_eff_scale": float(n_eff_scale), + "rho_bb": float(rho_bb), + } + if debug: + out["per_c"] = per_c + return out + + m_eff = tested + sidak_logp = _logsidak_from_logp(best_logp, m_eff) + + p_value = float(np.exp(sidak_logp)) + raw_min_p = float(np.exp(best_logp)) + p_value = _safe_p(max(p_value, p_floor)) + raw_min_p = _safe_p(max(raw_min_p, p_floor)) + + out = { + "p_value": p_value, + "raw_min_p": raw_min_p, + "log_p_single": float(best_logp), + "log_p_sidak": float(sidak_logp), + "neglog10_p_single": float(-best_logp / np.log(10)), + "neglog10_p_sidak": float(-sidak_logp / np.log(10)), + "best_c": best["c"], + "k_obs_eff": best["k_obs"], 
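+            # p0_abs / obs_prop: background vs. observed (expression-weighted)
+            # probability mass inside the best depleted region.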
+ "p0_abs": best["p0"], + "obs_prop": best["obs_prop"], + "scanned": int(scanned), + "tested": int(tested), + "sidak_penalty": 1, + "n": n, + "n_eff": float(n_eff), + "guards": { + "min_p0_abs": float(min_p0_abs), + "min_expected": float(min_expected), + "min_abs_deficit": float(min_abs_deficit), + "soft_bound": float(soft_bound), + "eps_rel": float(eps_rel), + }, + "n_eff_scale": float(n_eff_scale), + "rho_bb": float(rho_bb), + } + if debug: + out["per_c"] = per_c + return out + + # ------------------------------------------------------------------ + # Main scan used in practice + # ------------------------------------------------------------------ + def bic_score(self, gmm1, gene_prior): + bic_component_cost = self.n_dims * (self.n_dims + 3) / 2 + p = gmm1.pdf(self._embedding[gene_prior > 0, :]) + n_cells = np.sum(gene_prior > 0) + bic = ( + bic_component_cost * gmm1.n_comp * np.log(np.sum(gene_prior > 0)) + - 2 * np.sum(p[p > 0]) + ) + return bic / n_cells + + def gmm_scan_new( + self, + genes=None, + weights_transform=None, + zscore_thresh=None, + max_freq=0.5, + verbose=False, + n_bootstrap_inits=None, + # Depletion-scan defaults + rc_c_values=None, # default inside method + rc_min_p0_abs=0.10, #minimum proportion of f0 density in depleted region required for the region pval to be estimated + rc_min_expected=30, #minimum expected cells in depleted region required for the region pval to be estimated + rc_min_abs_deficit=0.02, #minimum absolute difference in f1(x) - f0(x) for all x in depleted region + rc_n_trials_cap=500, #maximum effective sample size + rc_soft_bound=1.0, #this is unused/can be removed + rc_n_eff_scale=0.5, #scaling factor for effective sample sizes -- can be tweaked to stabilize pvalues across various gene sample sizes + rc_p_floor=1e-12, # this is just model precision, can be ignored + rc_rho_bb=0.02, #this is the strength of the beta binomial (0.0 is standard binomial, set at 0.02-0.05 for wider tails) + rc_weight_mode="amount", + rc_eps_rel=0.01, + ): + if verbose: + print("gmm_scan_new: using depletion scan for localization_pval (localization_pval_dep_scan)") + + if n_bootstrap_inits is not None: + self.n_bootstrap_inits = int(n_bootstrap_inits) + + locally_enriched = dict() + gzeros, freqzeros, zzeros = [], [], [] + inclgenes = self.get_genes_indices(genes) + + f0 = self.background_pdf(weights_transform=weights_transform) + for i_gene in tqdm( + inclgenes, + desc="scanning genes", + position=0, + leave=True, + disable=self._disable_progress_info, + ): + gene_prior = self.background_pdf(i_gene, weights_transform) + try: + if np.sum(gene_prior) == 0: + gzeros.append(self._adata.var_names[i_gene]) + continue + if np.mean(gene_prior > 0) > max_freq: + freqzeros.append(self._adata.var_names[i_gene]) + continue + + comp_gene_prior = self._auto_n_effective_weights(gene_prior) + if comp_gene_prior is None: + n_comp = 1 + else: + loc_indices = comp_gene_prior > 0 + n_comp = self.auto_n_components( + self._embedding[loc_indices, :], + gene_prior[loc_indices], + loc_indices, + ) + gmm1 = self.signal_gmm(weights=gene_prior, n_comp=n_comp) + + i1 = gene_prior > 0 + loc_f1 = gmm1.pdf(self._embedding[i1, :]) + p1 = np.mean(i1) + sample_size = p1 * self.n_cells + + ltst_score = ltst_score_func(f0[i1], loc_f1, p1) + sens_score = sens_score_func(f0[i1], loc_f1, i1[i1]) + + zscore = np.dot(gene_prior[i1], ltst_score) / np.sum(gene_prior[i1]) + zscore = self._null_distribution.to_zscore(zscore, p1) + if (zscore_thresh is not None) and (zscore < zscore_thresh): + 
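+                    # Optional pre-filter: genes whose concentration z-score falls below
+                    # zscore_thresh are recorded in zzeros and skipped before the more
+                    # expensive depletion scan.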
zzeros.append(self._adata.var_names[i_gene]) + continue + + cs_res = self.localization_pval_dep_scan( + gmm1, + gene_prior, + debug=True, + c_values=( + rc_c_values + if rc_c_values is not None + else np.concatenate([[1.0], np.geomspace(1.05, 3.0, 12)]) + ), + soft_bound=rc_soft_bound, + min_p0_abs=rc_min_p0_abs, + min_expected=rc_min_expected, + min_abs_deficit=rc_min_abs_deficit, + n_trials_cap=rc_n_trials_cap, + weight_mode=rc_weight_mode, + p_floor=rc_p_floor, + n_eff_scale=rc_n_eff_scale, + rho_bb=rc_rho_bb, + eps_rel=rc_eps_rel, + ) + localization_pval = _safe_p(cs_res["p_value"]) + + concentration_pval = _safe_p(float(normal_sf(zscore, 0.0, 1.0))) + + p_cauchy = cauchy_combine([localization_pval, concentration_pval]) + p_size = _safe_p(1.0 - np.exp(-1.0 / (sample_size + 1.0))) + p_sens = _safe_p(1.0 - (sens_score + 1e-9)) + p_final = 1.0 - (1.0 - p_cauchy) * (1.0 - 0.15 * p_size) * ( + 1.0 - 0.35 * p_sens + ) + p_final = float(smooth_qvals(np.array([_safe_p(p_final)]))[0]) + + locally_enriched[self._adata.var_names[i_gene]] = { + "bic": self.bic_score(gmm1, gene_prior), + "zscore": zscore, + "sens_score": sens_score, + "localization_pval": localization_pval, + "concentration_pval": concentration_pval, + "pval": p_final, + "n_components": n_comp, + "sample_size": sample_size, + "depl_scan": cs_res, + } + + except ValueError as e: + if verbose: + logger.info(e) + + if verbose: + print("gzeros:", len(gzeros), "freqzeros:", len(freqzeros), "zzeros:", len(zzeros)) + return locally_enriched + + # ------------------------------------------------------------------ + # Small helpers still used by gmm_scan_new + # ------------------------------------------------------------------ + def get_genes_indices(self, genes): + inclgenes = range(self.n_genes) + list_genes = self._adata.var_names.tolist() + if genes is not None: + inclgenes = [list_genes.index(i) for i in genes] + return inclgenes + + def get_gene_prior(self, i_gene, weights_transform): + gene_prior = self.X[:, i_gene] + if weights_transform is not None: + gene_prior = weights_transform(gene_prior) + return gene_prior + + def signal_pdf(self, weights, n_comp=None): + gmm = self.signal_gmm(weights=weights, n_comp=n_comp, ) + return gmm.pdf(self._embedding) + + def gmm_loglikelihoodtest(self, genes=None, weights_transform=None, max_freq=0.5): + res = dict() + bkg_df = self.auto_bkg_components(self.n_cells) + log_bkg_pdf = self.background_pdf(weights_transform=weights_transform) + bkg_pdf_gt0 = log_bkg_pdf > 0 + log_bkg_pdf = np.where(bkg_pdf_gt0, np.log(log_bkg_pdf), 0) + + inclgenes = self.get_genes_indices(genes) + + for i_gene in tqdm(inclgenes, desc=f'scanning genes', + position=0, leave=True, disable=self._disable_progress_info): + gene_prior = self.get_gene_prior(i_gene, weights_transform) + gene_prior_gt0 = gene_prior > 0 + n_gene_prior = np.sum(gene_prior_gt0) + pv = 1.0 + df = None + lr = None + n_comp = 0 + sample_size = None + + if (n_gene_prior > 5) & ((n_gene_prior / len(gene_prior)) < max_freq): + comp_gene_prior = self._auto_n_effective_weights(gene_prior) + if comp_gene_prior is None: + n_comp = 1 + else: + loc_indices = comp_gene_prior > 0 + n_comp = self.auto_n_components( + self._embedding[loc_indices, :], + gene_prior[loc_indices], + indices=loc_indices, # <-- required + ) + + + f1 = self.signal_pdf(weights=gene_prior, n_comp=n_comp) + + df = (bkg_df - n_comp) + 1 + ix = gene_prior_gt0 & (f1 > 0) + if np.sum(ix) > 5: + ix = ix & bkg_pdf_gt0 + sample_size = np.sum(ix) + + if sample_size == 0: + pv = 0 + + elif 
sample_size > 5: + lr = self.calc_lratio(f1, ix, log_bkg_pdf, sample_size) + pv = 1 - chi2.cdf(lr, df) + + res[self._adata.var_names[i_gene]] = { + 'llratio_pvalue': pv, + 'llratio_sample_size': sample_size, + 'llratio_stat': lr, + 'llratio_df': df, + 'n_comp': n_comp, + } + + return res + + def gmm_local_pvalue( + self, + genes=None, + n_comp=None, + weights_transform=None, + alpha=0.05, + n_inits=100, + normalize_knn=True, + eps=1e-12, + ): + from scipy.stats import rankdata + from sklearn.metrics import roc_curve + import scipy.sparse as sp + import numpy as np + + f0 = np.asarray(self.background_pdf(weights_transform=weights_transform), dtype=self._dtype).ravel() + K = self.knn() + + # Optional: row-normalize K so every cell has comparable neighborhood "mass" + if normalize_knn: + if sp.issparse(K): + deg = np.asarray(K.sum(axis=1)).ravel() + inv = 1.0 / np.maximum(deg, eps) + Kuse = sp.diags(inv).dot(K) + else: + deg = K.sum(axis=1) + Kuse = K / np.maximum(deg[:, None], eps) + else: + Kuse = K + + locally_enriched = {} + inclgenes = self.get_genes_indices(genes) + + for i_gene in tqdm( + inclgenes, + desc="scanning genes", + position=0, + leave=True, + disable=self._disable_progress_info, + ): + gene_prior = self.get_gene_prior(i_gene, weights_transform) + gp = np.asarray(gene_prior, dtype=self._dtype).ravel() + + sw = float(gp.sum()) + if sw <= 0: + continue # or store null result if you prefer + + # signal pdf + LTST + f1 = np.asarray(self.signal_pdf(weights=gp, n_comp=n_comp), dtype=self._dtype).ravel() + p = float(np.mean(gp > 0)) + ltst_score = ltst_score_func(f0, f1, p) # (n_cells,) + + zscore = float(np.dot(gp, ltst_score) / sw) + + # random pseudo-genes (must return f2: (n_cells, n_inits), rweights: (n_cells, n_inits)) + f2, rweights = self.random_pdf(weights=gp, n_comp=n_comp, n_inits=n_inits) + f2 = np.asarray(f2, dtype=self._dtype) + rweights = np.asarray(rweights, dtype=self._dtype) + + # p2 is per-init expression fraction + p2 = np.mean(rweights > 0, axis=0).astype(self._dtype, copy=False) # (n_inits,) + + ltst_score2 = ltst_score_func(f0[:, None], f2, p2) # (n_cells, n_inits) + + # empirical zscore p-value (avoid 0/1 extremes) + z2 = (np.sum(rweights * ltst_score2, axis=0) / sw).astype(np.float64) # (n_inits,) + z_p = (1.0 + np.sum(z2 >= zscore)) / (n_inits + 1.0) + + # neighborhood-smoothed statistics + if sp.issparse(Kuse): + wstat1 = np.asarray(Kuse.dot(ltst_score)).ravel() + wstat2 = np.asarray(Kuse.dot(ltst_score2)) + else: + wstat1 = (Kuse @ ltst_score).ravel() + wstat2 = (Kuse @ ltst_score2) + + # empirical p-values per cell: rank among [wstat1, wstat2...] 
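+            # The observed smoothed statistic (wstat1) is ranked against the n_inits
+            # pseudo-gene statistics in its row; rank 1 (largest value) gives the
+            # smallest empirical p-value, 1 / (n_inits + 1).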
+ # shape: (n_cells, n_inits+1) + M = np.concatenate([wstat1[:, None], wstat2], axis=1) + ranks = rankdata(-M, axis=1, method="average")[:, 0] # rank of observed in each row + emp_p = ranks / (n_inits + 1.0) # in (0,1] + + # pick a cutoff from ROC if there is signal + empirical_h1 = emp_p < (alpha / self.n_cells) + + wstat1_cutoff = 0.0 + if np.any(empirical_h1) and np.any(~empirical_h1): + fpr, tpr, cuts = roc_curve(empirical_h1.astype(int), wstat1) + # maximize balanced accuracy ( (tpr + (1-fpr)) / 2 ) + j = np.argmax(((1 - fpr) + tpr) / 2.0) + wstat1_cutoff = float(cuts[j]) + + local_res = { + "wstat_cutoff": wstat1_cutoff, + "wstat_alpha": float(alpha), + "wstat_pvalues": emp_p.astype(np.float32, copy=False), + "zscore": float(zscore), + "zscore_pvalue": float(z_p), + "wstat_repetitions": int(n_inits), + } + + sig_idx = np.flatnonzero(wstat1 > wstat1_cutoff) + local_res["wstat_significant"] = sig_idx if sig_idx.size else None + local_res["wstat_significant_clusters"] = None # placeholder + + locally_enriched[self._adata.var_names[i_gene]] = local_res + + return locally_enriched + + + def gmm_local_scan(self, + genes=None, + weights_transform=None, + zscore_thresh=None, + max_freq=0.5): + + locally_enriched = dict() + if zscore_thresh is None: + zscore_thresh = 1.0 + K = self.knn() + + if sp.issparse(K): + knn_neis = np.asarray(K.sum(axis=1)).ravel() + inv = 1.0 / np.maximum(knn_neis, 1e-12) + # row-normalize: D^{-1} K + knn_norm = sp.diags(inv).dot(K) + else: + knn_neis = K.sum(axis=1) + knn_norm = K / np.maximum(knn_neis[:, None], 1e-12) + + inclgenes = self.get_genes_indices(genes) + + f0 = self.background_pdf(weights_transform=weights_transform) + f0 = np.asarray(f0, dtype=self._dtype) + for i_gene in tqdm(inclgenes, desc=f'scanning genes', + position=0, leave=True, disable=self._disable_progress_info): + gene_prior = self.get_gene_prior(i_gene, weights_transform) + try: + if np.sum(gene_prior) == 0: + continue + + if np.mean(gene_prior > 0) > max_freq: + continue + + comp_gene_prior = self._auto_n_effective_weights(gene_prior) + if comp_gene_prior is None: + n_comp = 1 + else: + loc_indices = comp_gene_prior > 0 + n_comp = self.auto_n_components( + self._embedding[loc_indices, :], + gene_prior[loc_indices], + indices=loc_indices, # <-- THIS is the missing piece + ) + + + f1 = self.signal_pdf(weights=gene_prior, n_comp=n_comp) + expr = (gene_prior > 0).astype(self._dtype, copy=False) + + # p1 = local expression fraction around each cell + if sp.issparse(knn_norm): + p1 = knn_norm.dot(expr) + else: + p1 = knn_norm @ expr + p1 = np.asarray(p1, dtype=self._dtype).ravel() + f1 = np.asarray(f1, dtype=self._dtype).ravel() + # compute ltst per-cell then smooth by knn + ltst = ltst_score_func( + np.asarray(f0, dtype=self._dtype), + np.asarray(f1, dtype=self._dtype), + p1, + ) + + if sp.issparse(knn_norm): + local_zscore = knn_norm.dot(ltst) + else: + local_zscore = knn_norm @ ltst + local_zscore = np.asarray(local_zscore, dtype=self._dtype).ravel() + + local_zscore = self._null_distribution.to_zscore(local_zscore, p1) + + K = self.knn() + expr_mask = (gene_prior > 0).astype(np.float32) # (n,) + sig_mask = (f1 > f0).astype(np.float32) # (n,) + + if sp.issparse(K): + expr_neighbors = K.dot(expr_mask) # how many expressing neighbors (weighted) + sig_expr_neighbors = K.dot(expr_mask * sig_mask) # how many expressing-and-signal neighbors + else: + expr_neighbors = K @ expr_mask + sig_expr_neighbors = K @ (expr_mask * sig_mask) + + local_lscore = sig_expr_neighbors / np.maximum(expr_neighbors, 
1e-12) + local_lscore = np.asarray(local_lscore).ravel() + + + # rate + # local_zscore = local_zscore * knn_neis / self.n_cells + assert K.shape == (self.n_cells, self.n_cells), (K.shape, self.n_cells) + assert f0.shape == (self.n_cells,), (f0.shape, self.n_cells) + assert f1.shape == (self.n_cells,), (f1.shape, self.n_cells) + + + localization_pval = localization_pvalue_nn_func(gene_prior, f1, f0, K) + concentration_pval = norm.sf(local_zscore) + #localization_pval = 1.0 + #concentration_pval = norm.sf(local_zscore) + + + if np.any(local_zscore > zscore_thresh): + locally_enriched[self._adata.var_names[i_gene]] = { + 'enriched': np.flatnonzero(local_zscore > zscore_thresh), + 'local_zscore': local_zscore, + 'local_lscore': local_lscore, + 'localization_pval': localization_pval, + 'concentration_pval': concentration_pval, + 'local_pvalue': localization_pval + concentration_pval - localization_pval*concentration_pval, + 'n_components': n_comp + } + except Exception as e: + logger.exception(e) # prints full traceback + raise # optional: stop on first failure + + return locally_enriched + + @staticmethod + def calc_lratio(f1, ix, log_bkg_pdf, sample_size, eps=1e-300): + """ + Per-cell LRT contribution on expressing cells: + -2 * sum_{i in ix} ( log f0(i) - log f1(i) ) / sample_size + """ + # ensure 1D arrays + f1 = np.asarray(f1).ravel() + ix = np.asarray(ix, dtype=bool).ravel() + log_bkg_pdf = np.asarray(log_bkg_pdf).ravel() + + # log f1 safely (no mutation) + log_f1 = np.log(np.clip(f1, eps, np.inf)) + + denom = float(sample_size) if sample_size > 0 else 1.0 + return float((-2.0 * np.sum((log_bkg_pdf[ix] - log_f1[ix])) ) / denom) + + def random_pdf( + self, + weights, + n_comp=None, + n_inits=300, + buckets=None, + ): + import numpy as np + from scipy.stats import multivariate_normal as mnorm # <-- add back + + weights = np.asarray(weights, dtype=self._dtype).ravel() + + if n_comp is None: + comp_weights = self._auto_n_effective_weights(weights) + if comp_weights is None: + n_comp = 1 + else: + loc_indices = comp_weights > 0 + n_comp = self.auto_n_components( + self._embedding[loc_indices, :], + weights[loc_indices], + indices=loc_indices, # <-- IMPORTANT + ) + + pis, mus, sigmas, rweights = softbootstrap_gmm( + self._embedding, + raw_weights=weights, + n_components=n_comp, + reg_covar=self.reg_covar(), + n_inits=n_inits, + buckets=buckets, + ) + + pis = np.asarray(pis) + mus = np.asarray(mus) + sigmas = np.asarray(sigmas) + rweights = np.asarray(rweights) + + dtot = np.zeros((self.n_cells, n_inits), dtype=np.float64) + + # NOTE: this loop is slow but correct; optimize later if needed + for j in range(n_inits): + for i in range(n_comp): + w = pis[j, i] + if w == 0: + continue + c0 = mnorm(mean=mus[j, i], cov=sigmas[j, i], allow_singular=True) + dtot[:, j] += w * c0.pdf(self._embedding) + + return dtot.astype(self._dtype, copy=False), rweights.astype(self._dtype, copy=False) + + +# ---------------------------------------------------------------------- +# JIT-accelerated scoring helpers +# ---------------------------------------------------------------------- +@numba.jit(nopython=True) +def ltst_score_func(f0, f1, p): + q = np.sqrt(p) + return (q * (1 - q)) * (f1 - f0) / (f1 + f0) + + +@numba.jit(nopython=True) +def sens_score_func(f0, f1, i): + return np.mean(f1[i > 0] > f0[i > 0]) + +import numpy as np +import scipy.sparse as sp + +def localization_pvalue_nn_func(x1, f1, f0, nn): + """ + Sparse-safe rewrite of the original localization_pvalue_nn_func. 
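+    For each cell it compares the (weighted) count of expressing neighbors where
+    f1 > f0 with the count where f1 < f0, and scores the signed balance against a
+    normal approximation (normal_sf below).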
+ + Preserves: + - i1/obs1/obs2 definitions + - n and o as neighborhood (weighted) counts / signed balance + - f2 weighting for global mu_hat, p1, p0 + - effective_n, sd_hat + - per-cell p via normal_sf(o[i], mu_hat, sd_hat[i]) + + Works for: + - nn sparse CSR/CSC/COO + - nn dense numpy array + """ + x1 = np.asarray(x1) + f1 = np.asarray(f1) + f0 = np.asarray(f0) + + # boolean masks + i1 = (x1 > 0) + obs1 = ((f1 > f0) & i1) + obs2 = ((f1 < f0) & i1) + + # convert to float vectors for dot products + i1_f = i1.astype(np.float32) + obs1_f = obs1.astype(np.float32) + obs2_f = obs2.astype(np.float32) + + # neighborhood weighted counts + if sp.issparse(nn): + # ensure CSR for fast dot + nn = nn.tocsr() + n = nn.dot(i1_f) + o = nn.dot(obs1_f) - nn.dot(obs2_f) + else: + n = nn @ i1_f + o = (nn @ obs1_f) - (nn @ obs2_f) + + n = np.asarray(n).ravel() + o = np.asarray(o).ravel() + + # normalize signed balance + o = o / np.clip(n, 1.0, np.inf) + + # same f2 + global mu_hat logic as original + f2 = 2.0 * (f1 * f0) / np.clip(f1 + f0, 1e-12, np.inf) + + denom = float(np.sum(f2)) + if denom <= 0 or not np.isfinite(denom): + # degenerate fallback: no information => p=1 everywhere + return np.ones(shape=(nn.shape[0],), dtype=np.float64) + + p1 = float(np.sum((f1 > f0) * f2) / denom) + p0 = float(np.sum((f1 < f0) * f2) / denom) + mu_hat = p1 - p0 + + effective_n = np.clip(n * (1.0 + 2.0 * np.abs(o / 2.0)), 1.0, np.inf) + sd_hat = np.sqrt(((p1 + p0) - ((p1 - p0) ** 2)) / effective_n) + + # compute p only where n>0 (same behavior as your loop) + p = np.ones(shape=(nn.shape[0],), dtype=np.float64) + idx = np.flatnonzero(n > 0) + + # vectorized normal_sf (your normal_sf is scalar-numba; keep loop or vectorize in numpy) + # We'll keep the loop for correctness with your numba normal_sf signature. + for i in idx: + p[i] = normal_sf(float(o[i]), mu=mu_hat, sigma=float(sd_hat[i])) + + return p + + +@numba.njit +def normal_sf(x, mu, sigma): + """ + Normal survival function SF = 1 - CDF, numba-jitted. + """ + z = (x - mu) / sigma + return 0.5 * math.erfc(z * 0.7071067811865476) # = 1 - Phi(z) + + +def _logsidak_from_logp(logp_min: float, m_eff: int) -> float: + """ + Sidák combine in log-space: + log p_sidák = log(1 - (1 - p_min)^m_eff) + = log(1 - exp(m_eff * log(1 - p_min))) + with log(1 - p_min) = log1p(-exp(logp_min)). + """ + if m_eff <= 1: + return float(logp_min) + l1mp = np.log1p(-np.exp(logp_min)) + return float(np.log1p(-np.exp(m_eff * l1mp))) + + +@numba.njit +def smooth_qvals(x): + return np.where( + x <= 0.99, + x, + 1 - 0.01 / (1 + 100 * (x - 0.99) + 10000 * (x - 0.99) ** 2), + ) + + +def _safe_p(p, eps=1e-15): + if not np.isfinite(p): + return 1.0 - eps + return float(min(max(p, eps), 1.0 - eps)) + + +def cauchy_combine(pvals, weights=None): + """ + Robust p-value combiner for dependent tests (Cauchy combination). + Liu & Xie (2020). + """ + ps = np.array([_safe_p(p) for p in pvals], dtype=np.float64) + if weights is None: + weights = np.ones_like(ps) + w = np.asarray(weights, dtype=np.float64) + w = w / (w.sum() if w.sum() > 0 else 1.0) + + t = np.sum(w * np.tan((0.5 - ps) * np.pi)) + pc = 0.5 - np.arctan(t) / np.pi + return _safe_p(pc) + + +def summarize_rc_debug(cs_res, top=8): + """ + Convenience helper to inspect per-threshold diagnostics from + localization_pval_dep_scan(..., debug=True). + """ + if "per_c" not in cs_res: + print("No per-threshold diagnostics captured. 
Run with debug=True.") + return + rows = cs_res["per_c"] + if not rows: + print("No thresholds scanned.") + return + from collections import Counter + + reasons = Counter(r.get("reason") for r in rows) + print("Reason counts:", dict(reasons)) + + cand = [r for r in rows if r.get("reason") is not None] + cand.sort(key=lambda r: r["n_eff_expected"], reverse=True) + print(f"\nTop {min(top, len(cand))} failing thresholds by n_eff_expected:") + for r in cand[:top]: + print( + f" c={r['c']:.3f} p0_abs={r['p0_abs']:.3f} " + f"obs_prop={r['obs_prop']:.3f} n_eff_expected={r['n_eff_expected']:.1f} " + f"reason={r['reason']}" + ) diff --git a/locat/plotting_and_other_methods.py b/locat/plotting_and_other_methods.py new file mode 100755 index 0000000..35e4ff6 --- /dev/null +++ b/locat/plotting_and_other_methods.py @@ -0,0 +1,383 @@ +from matplotlib.colors import LogNorm +from sklearn import mixture +import scanpy as sc +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt + +import itertools +def train_clustering(adata, genesuse, pthresh = 0.5): #doesnt need embedding coords, just uses neighbors attr. + import warnings + warnings.filterwarnings(action='ignore', category=UserWarning) + adata = adata[:,genesuse] + data = adata.X#.to_df() # counts + + sc.tl.leiden(adata, resolution=0.25) + #check if log1p exists, if so set = to none to avoid errors + if "log1p" in adata.uns.keys(): + adata.uns['log1p']["base"] = None + else: + print("No log1p uns is found") + from collections import Counter + print(Counter(adata.obs["leiden"])) + un = list(np.unique(adata.obs["leiden"])) + #print("newleiden",newleiden[:10]) + exclude = [] + for i in un: + if ((adata[adata.obs["leiden"]==i,:].copy().shape[0]) < 2): + exclude.append((i)) + include = [i for i in un if i not in exclude] + sc.tl.rank_genes_groups(adata, 'leiden', method='wilcoxon', groups = include) + names = np.array(adata.uns['rank_genes_groups']['names'].tolist()).flatten() + pvals = np.array(adata.uns['rank_genes_groups']['pvals'].tolist()).flatten() + logfcs = np.array(adata.uns['rank_genes_groups']['logfoldchanges'].tolist()).flatten() + #print(adata.uns['rank_genes_groups']['names']) + + maxnames = [] + minnames = [] + maxpvals = [] + maxlogfcs = [] + notsignames = [] + minpvals=[] + #for each unique value + for i in np.unique(names): + #get indexes of this value + inds = list(np.where(names == i)) + #look up which have pvals <.05 + #update + if (pthresh!=0.5): + + i_names = list(itertools.compress(names[inds].tolist()[0],(pvals[inds]<100).tolist()[0])) + i_pvals = list(itertools.compress(pvals[inds].tolist()[0],(pvals[inds]<100).tolist()[0])) + lowest = sorted(i_pvals, reverse=False)[0] + minind = np.where(np.array(i_pvals)==lowest)[0][0] + minnames.append(i_names[minind]) + minpvals.append(i_pvals[minind]) + continue + if (np.sum(pvals[inds]<0.05)>0): + #print(np.array(names)[inds]) + #print(pvals[inds]<0.05) + i_names = list(itertools.compress(names[inds].tolist()[0],(pvals[inds]<0.05).tolist()[0])) + i_pvals = list(itertools.compress(pvals[inds].tolist()[0],(pvals[inds]<0.05).tolist()[0])) + i_logfcs = list(itertools.compress(logfcs[inds].tolist()[0],(pvals[inds]<0.05).tolist()[0])) + #get new list of indices + + inds = list(range(len(i_names))) + + #lookup which has highest logfc + + highest = sorted(i_logfcs, reverse=True)[0] + lowest = sorted(i_pvals, reverse=False)[0] + #get maximum index + maxind = np.where(np.array(i_logfcs)==highest)[0][0] + minind = np.where(np.array(i_pvals)==lowest)[0][0] + #print("maxind",maxind) + 
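+            # maxind / minind: positions (within the significant entries for this gene)
+            # of the largest log fold-change and the smallest p-value, respectively.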
#update arrays + maxnames.append(i_names[maxind]) + minnames.append(i_names[minind]) + maxpvals.append(i_pvals[maxind]) + minpvals.append(i_pvals[minind]) + maxlogfcs.append(i_logfcs[maxind]) + else: + notsignames.append(i) + debugging = [maxlogfcs,maxnames] + sortinds = [np.argsort(maxlogfcs)[::-1]] + maxlogfcs = np.array(maxlogfcs)[sortinds] + #print(maxlogfcs[:10]) + maxpvals = np.array(maxpvals)[sortinds] + maxnames = np.array(maxnames)[sortinds] + sortinds = [np.argsort(minpvals)]#[::-1]] + minnames = np.array(minnames)[sortinds] + print("maxnames_len",len(maxnames)) + return (maxlogfcs,maxpvals,maxnames,minnames, notsignames, sortinds) + +def train_clustering_logpadj(adata, genesuse, resolution=0.25, method="wilcoxon"): + ad = adata[:, genesuse].copy() + if "connectivities" not in ad.obsp: + raise ValueError("Neighbors missing: run sc.pp.neighbors first (use_rep='coords' etc).") + + if np.min(ad.X) >= 0: + sc.pp.log1p(ad) + + sc.tl.leiden(ad, resolution=resolution) + if "log1p" in ad.uns: + ad.uns["log1p"]["base"] = None + + sizes = ad.obs["leiden"].value_counts() + groups = sizes.index[sizes >= 2].tolist() + if not groups: + raise ValueError("All Leiden clusters have <2 cells.") + + sc.tl.rank_genes_groups(ad, "leiden", groups=groups, method=method) + + rg = ad.uns["rank_genes_groups"] + names = pd.DataFrame(rg["names"]).loc[:, groups].to_numpy().T + padj = pd.DataFrame(rg["pvals_adj"]).loc[:, groups].to_numpy().T + if "logfoldchanges" in rg: + lfc = pd.DataFrame(rg["logfoldchanges"]).loc[:, groups].to_numpy().T + else: + lfc = np.full(padj.shape, np.nan, dtype=float) + + n_groups, n_top = names.shape + df = pd.DataFrame({ + "gene": names.ravel(), + "group": np.repeat(np.array(groups, dtype=object), n_top), + "pvals_adj": padj.ravel(), + "logfoldchange": lfc.ravel(), + }).dropna(subset=["gene", "pvals_adj"]).sort_values(["gene", "pvals_adj"], kind="mergesort") + + best = df.drop_duplicates("gene", keep="first").set_index("gene") + best["log10_padj"] = -np.log10(np.maximum(best["pvals_adj"].to_numpy(), 1e-300)) + return best.sort_values("pvals_adj") + +def parse_output(adata,res,localres): + #get vector of wstats for every gene + #also, get expression volume contained in each gene + keys = list(res.keys()) + varnames = list(adata.var_names) + wstats = [j["local_zscore"] for i,j in localres.items()] + + enrvols = [] + enrspecs = [] + + for wst_i, c in zip(wstats,range(len(keys))): + enr_ind = np.array(wst_i) > 0 + #specificity: proportion of enriched cells that are expressing the gene + enrspecs.append( np.sum(np.logical_and(adata.X[:,c]>0,enr_ind))/np.sum(enr_ind) ) + #sensitivity: proportion of total expression volume accounted for by enriched cells + enrvols.append( np.sum(adata.X[:,c] * (enr_ind/1))/np.sum(adata.X[:,c]) ) + + return(np.array(wstats).T, enrvols, enrspecs) + +def plot_gene_localization_summary( + genes, + locat_df, + adata, + suptitle="Gene Localization Summary", + embedding_key="X_umap", + embedding_dims=2 +): + """ + Plots expression, GMM fit, and localized masks for each gene. 
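+    Assumes each row of locat_df provides 'gmm', 'gene_prior', 'm_cuts',
+    'components_to_use' and 'localization_pval', as read below.
+    """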
+ """ + umap = adata.obsm[embedding_key][:, :embedding_dims] + n_genes = len(genes) + + # One row per gene, 3 plots per row + fig, axes = plt.subplots( + nrows=n_genes, + ncols=3, + figsize=(15, 5 * n_genes), + squeeze=False + ) + + for i, gene in enumerate(genes): + # Extract data from locat_df + rec = locat_df.loc[gene] + gmm = rec["gmm"] + gene_prior = rec["gene_prior"] + m_cuts = rec["m_cuts"] + components_to_use = rec["components_to_use"] + + # Expression vector + expr = gene_prior + is_expressing = expr > 0 + + # GMM density + density = gmm.pdf(umap) + + # Mahalanobis distances + m0 = gmm.mahalanobis_dist(umap) + + # Compute is_localized mask + is_localized = ~np.all(m0 > m_cuts[None, :], axis=1) + + # Localized vs unlocalized masks + loc_mask = is_expressing & is_localized + unloc_mask = is_expressing & (~is_localized) + + ### 1️⃣ Expression plot (ordered by expression) + sort_idx = np.argsort(expr) + ax = axes[i, 0] + sc = ax.scatter( + umap[sort_idx, 0], + umap[sort_idx, 1], + c=expr[sort_idx], + cmap="viridis", + s=60, + edgecolor="none" + ) + ax.set_title(f"{gene}\nExpression", fontsize=12) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_aspect("equal") + cbar = fig.colorbar(sc, ax=ax, shrink=0.7) + cbar.set_label("Expression") + + ### 2️⃣ GMM density plot (ordered by density) + sort_idx_dens = np.argsort(density) + ax = axes[i, 1] + sc = ax.scatter( + umap[sort_idx_dens, 0], + umap[sort_idx_dens, 1], + c=density[sort_idx_dens], + cmap="magma", + s=20, + edgecolor="none" + ) + ax.set_title(f"{gene}\nGMM Density", fontsize=12) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_aspect("equal") + cbar = fig.colorbar(sc, ax=ax, shrink=0.7) + cbar.set_label("Density") + + ### 3️⃣ Localized vs unlocalized plot (order: background, unloc, loc) + ax = axes[i, 2] + # Background + ax.scatter( + umap[:, 0], + umap[:, 1], + c="lightgrey", + s=10, + alpha=0.3, + edgecolor="none" + ) + # Unlocalized + ax.scatter( + umap[unloc_mask, 0], + umap[unloc_mask, 1], + c="blue", + s=25, + edgecolor="none", + label="Unlocalized" + ) + # Localized + ax.scatter( + umap[loc_mask, 0], + umap[loc_mask, 1], + c="red", + s=35, + edgecolor="black", + linewidth=0.3, + label="Localized" + ) + ax.set_title( + ( + f"{gene}\nLocalized: {loc_mask.sum()} | " + f"Unloc: {unloc_mask.sum()} | " + f"ExpUnloc: {rec.get('expected_unlocalized', np.nan):.1f}\n" + f"LocPval: {rec['localization_pval']:.2e}" + ), + fontsize=11 + ) + ax.legend(loc="upper right", fontsize=8, frameon=True) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_aspect("equal") + + fig.suptitle(suptitle, fontsize=18, y=1.02) + plt.tight_layout() + fig.subplots_adjust(hspace=0.4) + plt.show() + + +def plotgenes( + adata, + d0, + topgenes, + suptitle="TITLEARG", + size=10, + emb="X_umap", + genes_per_row=5, + text_size=12, + geneinf = False +): + + + # Convert gene list + genes = np.array(topgenes) + n_genes = len(genes) + + # UMAP coordinates + umap = adata.obsm[emb] + + # Grid size + ncols = genes_per_row + nrows = int(np.ceil(n_genes / ncols)) + + # Create figure + fig, axes = plt.subplots( + nrows=nrows, + ncols=ncols, + figsize=(ncols * 4, nrows * 4), + squeeze=False + ) + + # Flatten axes + axes_flat = axes.flatten() + + # Loop through genes + for i, (gene, ax) in enumerate(zip(genes, axes_flat)): + expr = adata[:, gene].X + + # Convert sparse matrix if needed + if not isinstance(expr, np.ndarray): + expr = expr.toarray().flatten() + else: + expr = expr.flatten() + + # Order cells so high-expression on top + order = np.argsort(expr) + x = umap[order, 0] 
+ y = umap[order, 1] + c = expr[order] + + sc = ax.scatter( + x, y, + c=c, + cmap="viridis", + s=size, + edgecolor="none" + ) + + # Multiline title + if geneinf: + ax.set_title( + f"{gene}\n" + f"pval: {d0.loc[gene]['pval']:.2e}\n" + f"conc_pval: {d0.loc[gene]['concentration_pval']:.2e}\n" + f"loca_pval: {d0.loc[gene]['localization_pval']:.2e}", + fontsize=text_size, + pad=10 + ) + else: + ax.set_title( + f"{gene}\n" + f"pval: {d0.loc[gene]['pval']:.2e}\n", + fontsize=text_size, + pad=2 + ) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_aspect("equal") + + # Colorbar for each subplot + #cbar = fig.colorbar(sc, ax=ax, shrink=0.7) + #cbar.set_label("Expression", fontsize=text_size-2) + + # Remove unused axes + for j in range(i + 1, len(axes_flat)): + fig.delaxes(axes_flat[j]) + + # Add suptitle + fig.suptitle( + suptitle, + fontsize=text_size + 6, + y=1.02 + ) + + plt.tight_layout() + fig.subplots_adjust(hspace=0.4) + plt.show() + diff --git a/locat/rgmm.py b/locat/rgmm.py new file mode 100755 index 0000000..9e4d7da --- /dev/null +++ b/locat/rgmm.py @@ -0,0 +1,207 @@ +from functools import partial +import numpy as np +import jax +import jax.numpy as jnp +import jax.scipy as jsp +import tensorflow_probability.substrates.jax as jaxp + +jax.config.update("jax_enable_x64", True) + +from sklearn.cluster import kmeans_plusplus +from loguru import logger + +for d in jax.local_devices(): + logger.info(f'Found device: {d}') +jaxd = jaxp.distributions + + +def _weighted_kmeans_init(X, w, n_c, n_inits): + return map( + lambda i: kmeans_plusplus(X, n_clusters=n_c, sample_weight=w[:, i])[0], + np.arange(n_inits),) + + +def weighted_gmm_init(X, w, n_c, n_inits): + n = X.shape[-1] + mu_init = jnp.array(list(_weighted_kmeans_init(X, w, n_c, n_inits))) + sigma_init = jnp.tile(jnp.eye(n)[None, ...], (n_inits, n_c, 1, 1)) + return mu_init, sigma_init + + +@jax.jit +def e_step(X, pi, mu, sigma): + mixture_log_prob = jaxd.MultivariateNormalTriL( + loc=mu, + scale_tril=jnp.linalg.cholesky(sigma) + ).log_prob(X[:, None, ...]) + jnp.log(pi) + log_membership_weight = mixture_log_prob - jsp.special.logsumexp(mixture_log_prob, axis=-1, keepdims=True) + return jnp.exp(log_membership_weight) + + +@jax.jit +def weighted_m_step(X, membership_weight, sample_weights, reg_covar): + n, m = X.shape + w = membership_weight * sample_weights[..., None] + w_sum = w.sum(0) + pi_updated = w_sum / n + pi_updated /= np.sum(pi_updated) + w = w / w_sum + + mu_updated = jnp.sum( + X[:, None, ...] * w[..., None], + axis=0) + + centered_x = X[:, None, ...] 
- mu_updated + + sigma_updated = jnp.sum( + jnp.einsum('...i,...j->...ij', centered_x, centered_x) * + w[..., None, None], + axis=0) + + sigma_updated = sigma_updated + jnp.diag(jnp.ones(shape=(m,))*reg_covar)[None, :] + + return pi_updated, mu_updated, sigma_updated + + +@jax.jit +def compute_loss(X, pi, mu, sigma, membership_weight): + component_log_prob = jaxd.MultivariateNormalTriL( + loc=mu, + scale_tril=jnp.linalg.cholesky(sigma) + ).log_prob(X[:, None, ...]) + + loss = membership_weight * ( + jnp.log(pi) + component_log_prob - jnp.log( + jnp.clip(membership_weight, + a_min=jnp.finfo(np.float64).eps))) + return jnp.sum(loss) + + +@partial(jax.jit, static_argnums=(4, 5, 6, 7, 8, 9)) +def train_em(X, samples_weights, mu_init, sigma_init, n_components, + n_inits=25, reg_covar=0.0, rtol=1e-6, max_iter=500, seed=1): + + def cond_fn(state): + i, thetas, loss, loss_diff, _ = state + return jnp.all((i < max_iter) & (loss_diff > rtol)) + + @jax.vmap + def one_step(state): + i, (pi, mu, sigma), loss, loss_diff, sample_weights = state + membership_weight = e_step(X, pi, mu, sigma) + + pi_updated, mu_updated, sigma_updated = weighted_m_step(X, membership_weight, sample_weights, reg_covar) + loss_updated = compute_loss( + X, pi_updated, mu_updated, sigma_updated, membership_weight) + loss_diff = jnp.abs((loss_updated / loss) - 1.) + + return (i + 1, + (pi_updated, mu_updated, sigma_updated), + loss_updated, + loss_diff, + sample_weights) + + key = jax.random.PRNGKey(seed) + raw_pi_init = jax.random.uniform(key, shape=(n_inits, n_components)) + pi_init = raw_pi_init / raw_pi_init.sum(-1, keepdims=True) + key, subkey = jax.random.split(key) + + init_val = (jnp.zeros([n_inits], jnp.int32), + (pi_init, mu_init, sigma_init), + -jnp.ones([n_inits]) * jnp.inf, + jnp.ones([n_inits]) * jnp.inf, + samples_weights.T) + + num_iter, (pi_est, mu_est, sigma_est), loss, loss_diff, _ = jax.lax.while_loop(cond_fn, one_step, init_val) + + # index = jnp.argmax(loss) + # pi_best, mu_best, sigma_best = jax.tree_map(lambda x: x[index], (pi_est, mu_est, sigma_est)) + + return pi_est, mu_est, sigma_est + + +def softbootstrap_gmm(X, raw_weights, n_components, n_inits=100, reg_covar=0.0, seed=1, buckets=None): + if buckets is None: + buckets = np.clip(int(len(raw_weights)/30), 3, 30) + + rand_weights = raw_weights + o_fwd = np.argsort(raw_weights) + o_back = np.argsort(o_fwd) + + rand_weights = np.repeat(rand_weights[o_fwd, None], n_inits, axis=1) + rng = np.random.default_rng(seed) + rand_weights = np.concatenate([rng.permuted(i, axis=0) for i in np.array_split(rand_weights, buckets, axis=0)], axis=0) + rand_weights = rand_weights[o_back,:] + + n = rand_weights.shape[0] + boot_weights = np.random.geometric(1 / n, size=rand_weights.shape) + boot_weights = (boot_weights / np.sum(boot_weights, axis=0)[None, :]) + weights = rand_weights * boot_weights + return rgmm(X, weights, n_components, n_inits, reg_covar, rand_weights) + + +def hardbootstrap_gmm(X, raw_weights, n_components, fraction, n_inits=30, reg_covar=0.0, seed=1): + """ + fraction: what proportion of items to sample + """ + norm_weights = raw_weights / np.sum(raw_weights) + fraction = np.clip(fraction, 0, 1) + n_points = X.shape[0] + n_samples = np.maximum(1, int(n_points * fraction)) + + weights = np.zeros(shape=(n_points, n_inits)) + for i in range(n_inits): + sampled_indices = np.random.choice( + n_points, + size=n_samples, + replace=True, + p=norm_weights + ) + i0, c0 = np.unique(sampled_indices, return_counts=True) + weights[i0, i] = c0 # 
raw_weights[sampled_indices] + + return rgmm(X, weights, n_components, n_inits, reg_covar) + + +def simplebootstrap_gmm(X, n_components, fraction, n_inits=30, reg_covar=0.0, seed=1): + """ + fraction: what proportion of items to sample + """ + fraction = np.clip(fraction, 0, 1) + n_points = X.shape[0] + n_samples = np.maximum(1, int(n_points * fraction)) + + weights = np.zeros(shape=(n_points, n_inits)) + for i in range(n_inits): + sampled_indices = np.random.choice( + n_points, + size=n_samples, + replace=False, + ) + weights[sampled_indices, i] = 1/n_samples # raw_weights[sampled_indices] + + return rgmm(X, weights, n_components, n_inits, reg_covar) + + +def rgmm(X, weights, n_components, n_inits, reg_covar, true_weights=None): + if true_weights is None: + true_weights = weights + weights = weights / np.sum(weights) + mu_init, sigma_init = weighted_gmm_init(X, + w=weights, + n_c=n_components, + n_inits=n_inits) + pi_est, mu_est, sigma_est = train_em(X, + samples_weights=weights, + mu_init=mu_init, + sigma_init=sigma_init, + n_components=n_components, + n_inits=n_inits, + reg_covar=reg_covar, + rtol=1e-6, + max_iter=500, + seed=1) + + return np.array(pi_est), np.array(mu_est), np.array(sigma_est), true_weights + + diff --git a/locat/wgmm.py b/locat/wgmm.py new file mode 100755 index 0000000..73e339e --- /dev/null +++ b/locat/wgmm.py @@ -0,0 +1,93 @@ +import numpy as np +from scipy.stats import multivariate_normal as mnorm +from dataclasses import dataclass + + +@dataclass +class WGMM: + pis: np.ndarray + mus: np.ndarray + sigmas: np.ndarray + + @property + def n_comp(self) -> int: + return len(self.pis) + + def pdf(self, coords): + pdf = np.zeros(shape=(coords.shape[0],)) + for i in range(self.n_comp): + c0 = mnorm(mean=self.mus[i], cov=self.sigmas[i]) + pdf += self.pis[i] * c0.pdf(coords) + return pdf + + def loglikelihood_by_component(self, coords, weights): + # estimate the pdf of each component separately + logpdf = np.zeros(shape=(coords.shape[0], self.n_comp)) + for i in range(self.n_comp): + c0 = mnorm(mean=self.mus[i], cov=self.sigmas[i]) + logpdf[:, i]= c0.logpdf(coords) * weights + + # estimate the log-likelihood of each component + return logpdf + + def mahalanobis_dist(self, coords): + """ + Compute Mahalanobis distance from each point to each component's peak (vectorized). 
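+        Uses d_i(x) = sqrt((x - mu_i)^T Sigma_i^{-1} (x - mu_i)) for each component i.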
+ + Parameters: + ----------- + coords : np.ndarray + Input coordinates array with shape (n_points, n_dimensions) + + Returns: + -------- + np.ndarray + Mahalanobis distances with shape (n_points, n_components) + Each row represents a point, each column represents a component + """ + n_points = coords.shape[0] + distances = np.zeros((n_points, self.n_comp)) + + for i in range(self.n_comp): + # Get mean and covariance for component i + mean = self.mus[i] + cov = self.sigmas[i] + + # Compute inverse of covariance matrix with regularization + try: + inv_cov = np.linalg.inv(cov) + except np.linalg.LinAlgError: + # Add small regularization term for numerical stability + reg_cov = cov + np.eye(cov.shape[0]) * 1e-6 + inv_cov = np.linalg.inv(reg_cov) + + # Vectorized computation for all points + diff = coords - mean[None,:] # Broadcasting: (n_points, n_dims) - (n_dims,) + + # Compute quadratic form: (x-μ)ᵀ Σ⁻¹ (x-μ) for each point + quad_form = np.sum((diff @ inv_cov) * diff, axis=1) + distances[:, i] = np.sqrt(quad_form) + + return distances + + def loglikelihood_truncated(self, coords, weights, top_n_components=None): + if top_n_components is None: + top_n_components = self.n_comp + + # speeding up + coords = coords[weights>0] + weights = weights[weights>0] + + # top n components by number of explained points + logpdf = self.loglikelihood_by_component(coords=coords, weights=weights) + + # Get indices of top n components using argpartition (efficient) + if top_n_components...ij', centered_x, centered_x) * + w[..., None, None], + axis=0) + + sigma_updated = sigma_updated + jnp.diag(jnp.ones(shape=(m,))*reg_covar)[None, :] + + return pi_updated, mu_updated, sigma_updated + + +@jax.jit +def compute_loss(X, pi, mu, sigma, membership_weight): + component_log_prob = jaxd.MultivariateNormalTriL( + loc=mu, + scale_tril=jnp.linalg.cholesky(sigma) + ).log_prob(X[:, None, ...]) + + loss = membership_weight * ( + jnp.log(pi) + component_log_prob - jnp.log( + jnp.clip(membership_weight, + a_min=jnp.finfo(np.float64).eps))) + return jnp.sum(loss) + + +@partial(jax.jit, static_argnums=(4, 5, 6, 7, 8, 9)) +def train_em(X, sample_weights, mu_init, sigma_init, n_components, + n_inits=25, reg_covar=0.0, rtol=1e-6, max_iter=500, seed=1): + def cond_fn(state): + i, thetas, loss, loss_diff = state + return jnp.all((i < max_iter) & (loss_diff > rtol)) + + @jax.vmap + def one_step(state): + i, (pi, mu, sigma), loss, loss_diff = state + membership_weight = e_step(X, pi, mu, sigma) + + pi_updated, mu_updated, sigma_updated = weighted_m_step(X, membership_weight, sample_weights, reg_covar) + loss_updated = compute_loss( + X, pi_updated, mu_updated, sigma_updated, membership_weight) + loss_diff = jnp.abs((loss_updated / loss) - 1.) 
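+        # Relative change in log-likelihood between EM iterations; compared against rtol in cond_fn.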
+ + return (i + 1, + (pi_updated, mu_updated, sigma_updated), + loss_updated, + loss_diff) + + key = jax.random.PRNGKey(seed) + raw_pi_init = jax.random.uniform(key, shape=(n_inits, n_components)) + pi_init = raw_pi_init / raw_pi_init.sum(-1, keepdims=True) + key, subkey = jax.random.split(key) + + init_val = (jnp.zeros([n_inits], jnp.int32), + (pi_init, mu_init, sigma_init), + -jnp.ones([n_inits]) * jnp.inf, + jnp.ones([n_inits]) * jnp.inf) + + num_iter, (pi_est, mu_est, sigma_est), loss, loss_diff = jax.lax.while_loop(cond_fn, one_step, init_val) + + index = jnp.argmax(loss) + pi_best, mu_best, sigma_best = jax.tree.map(lambda x: x[index], (pi_est, mu_est, sigma_est)) + + return pi_best, mu_best, sigma_best, loss + + +def wgmm(X, raw_weights, n_components, n_inits=1, reg_covar=0.0): + norm_weights = raw_weights / np.sum(raw_weights) + mu_init, sigma_init = weighted_gmm_init(X, + w=norm_weights, + n_c=n_components, + n_inits=n_inits) + return train_em(X, + sample_weights=norm_weights, + mu_init=mu_init, + sigma_init=sigma_init, + n_components=n_components, + n_inits=n_inits, + reg_covar=reg_covar, + rtol=1e-6, + max_iter=500, + seed=1) + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..20c156d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,54 @@ +[build-system] +requires = ["flit_core", "flit_scm"] +build-backend = "flit_scm:buildapi" + +[project] +name = "locat" +dynamic = ['version'] +description = "Locat identifies marker genes enriched within a compact subset of cells and depleted elsewhere" +license = {file = "LICENSE"} +readme = "README.md" + +authors = [ + {name = "Wes Lewis", email = "wes.lewis@yale.edu"}, + {name = "Fabio Parisi", email = "fabio.parisi@pcmgf.com"}, + {name = "Francesco Strino", email = "francesco.strino@pcmgf.com"}, + {name = "Yuval Kluger", email = "yuval.kluger@yale.edu"}, +] +maintainers = [ + {name = "Wes Lewis", email = "wes.lewis@yale.edu"}, + {name = "Fabio Parisi", email = "fabio.parisi@pcmgf.com"}, + {name = "Francesco Strino", email = "francesco.strino@pcmgf.com"}, +] + +keywords = ["Locat", "Localized genes"] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "License :: OSI Approved :: BSD License", + "Intended Audience :: Science/Research", +] + +requires-python = ">=3.10" +dependencies = [ + "scanpy", + "loguru", + "jax", + "tensorflow_probability", +] + +[tool.flit.sdist] +exclude = [".gitignore"] + +[tool.setuptools_scm] +write_to = "locat/_version.py" + +[project.optional-dependencies] +dev = ["pytest", "flit_core", "flit_scm"] +cuda12 = ["jax[cuda12]"] +cuda13 = ["jax[cuda13]"] + +[project.urls] +Documentation = "https://github.com/KlugerLab/Locat" +Repository = "https://github.com/KlugerLab/Locat" +"Bug Tracker" = "https://github.com/KlugerLab/Locat/issues" From 480478d8dac814ebdb5ef2b68c3f3db9991044bd Mon Sep 17 00:00:00 2001 From: Francesco Strino Date: Tue, 10 Feb 2026 13:50:31 +0200 Subject: [PATCH 2/2] Github workflows --- .github/workflows/build_and_test.yml | 30 ++++++++ .github/workflows/publish_documentation.yml | 46 ++++++++++++ .github/workflows/publish_github_release.yml | 74 ++++++++++++++++++++ .github/workflows/publish_pypi_release.yml | 58 +++++++++++++++ 4 files changed, 208 insertions(+) create mode 100644 .github/workflows/build_and_test.yml create mode 100644 .github/workflows/publish_documentation.yml create mode 100644 .github/workflows/publish_github_release.yml create mode 100644 .github/workflows/publish_pypi_release.yml diff --git 
a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml new file mode 100644 index 0000000..15e63cb --- /dev/null +++ b/.github/workflows/build_and_test.yml @@ -0,0 +1,30 @@ +# See e.g. https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python +name: Build and test + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +defaults: + run: + shell: bash -l {0} +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Test with pytest + run: | + pip install pytest pytest-cov + pytest tests diff --git a/.github/workflows/publish_documentation.yml b/.github/workflows/publish_documentation.yml new file mode 100644 index 0000000..d379b2f --- /dev/null +++ b/.github/workflows/publish_documentation.yml @@ -0,0 +1,46 @@ +# Adapted from https://coderefinery.github.io/documentation/gh_workflow/ +name: documentation + +on: + push: + tags: + - 'v*' + workflow_dispatch: + +permissions: + contents: write + pages: write + id-token: write + +concurrency: + group: "pages" + cancel-in-progress: false + +jobs: + docs: + runs-on: ubuntu-latest + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r docs/requirements.txt + pip install -e . + - name: Sphinx build + run: | + sphinx-build docs/source _build + - name: Setup Pages + uses: actions/configure-pages@v5 + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: '_build' + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/publish_github_release.yml b/.github/workflows/publish_github_release.yml new file mode 100644 index 0000000..abdcd30 --- /dev/null +++ b/.github/workflows/publish_github_release.yml @@ -0,0 +1,74 @@ +# +# Create a new GitHub release (do not run if one already exists!). 
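+# The release is named after ${{ github.ref_name }}, so dispatch this workflow from the tag you want to publish.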
+# +# Adapted from https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/ +name: Publish Python distribution to Github + +on: + workflow_dispatch: + +jobs: + build: + name: Build distribution 📦 + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install pypa/build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: python3 -m build + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + github-release: + name: >- + Sign the Python distribution with Sigstore + and upload them to GitHub Release + runs-on: ubuntu-latest + + permissions: + contents: write # IMPORTANT: mandatory for making GitHub Releases + id-token: write # IMPORTANT: mandatory for sigstore + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + - name: Sign the dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v3.0.0 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + - name: Create GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + run: >- + gh release create + '${{ github.ref_name }}' + --repo '${{ github.repository }}' + --notes "" + - name: Upload artifact signatures to GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + # Upload to GitHub Release using the `gh` CLI. + # `dist/` contains the built packages, and the + # sigstore-produced signatures and certificates. + run: >- + gh release upload + '${{ github.ref_name }}' dist/** + --repo '${{ github.repository }}' diff --git a/.github/workflows/publish_pypi_release.yml b/.github/workflows/publish_pypi_release.yml new file mode 100644 index 0000000..e2aed31 --- /dev/null +++ b/.github/workflows/publish_pypi_release.yml @@ -0,0 +1,58 @@ +# +# Creates a new PyPI release when a tag is pushed from git +# - trigger only when a tag pushed, use the format 'v1.2.3' +# - uploads the new version on PyPI https://pypi.org/p/dyson-equalizer +# Adapted from https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/ +name: Publish Python distribution to PyPI + +on: + push: + tags: + - 'v*' + workflow_dispatch: + +jobs: + build: + name: Build distribution + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install pypa/build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: python3 -m build + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + publish-to-pypi: + name: >- + Publish Python distribution to PyPI + needs: + - build + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/dyson-equalizer + permissions: + id-token: write + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1
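
A minimal usage sketch of the bootstrap GMM interface added in this patch (illustration only, not part of the diff; the toy data, sizes, and parameter values are assumptions):

    import numpy as np
    from locat.rgmm import softbootstrap_gmm

    rng = np.random.default_rng(0)
    X = rng.normal(size=(300, 2))    # toy stand-in for a 2-D cell embedding
    w = rng.random(300)              # non-negative per-cell weights, e.g. one gene's expression
    # One GMM fit per bootstrap init: pi has shape (8, 2), mu (8, 2, 2), sigma (8, 2, 2, 2)
    pi, mu, sigma, used_w = softbootstrap_gmm(X, raw_weights=w, n_components=2, n_inits=8, reg_covar=1e-4)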