def annotate_hierarchy(
    adata,
    marker_hierarchy,
    group_name,
    method="auroc",
    layer=None,
    min_expr=0.1,
    **kwargs
):
    """
    Annotate clusters based on a manually defined cell type and marker hierarchy.

    Parameters
    ----------
    adata
        AnnData object
    marker_hierarchy
        Dict with marker genes for each celltype arranged hierarchically.
    group_name
        Name of the column in adata.obs that contains the cluster labels
    method
        Method to use for annotation. Options are "auroc", "trinatize"
        (also accepted as "trinarize") and "degenes".
    layer
        Layer in adata to use for expression
    min_expr
        Minimum expression score passed on to ``get_annot_df``.
    **kwargs
        Additional arguments to pass to the annotation function.

    Returns
    -------
    dict
        ``assignments`` (output of ``get_annot_df``) and ``metrics``
        (the per-level assignment tables).
    """
    # Annotate at each level of the hierarchy. `layer` and `kwargs` are
    # forwarded so that they actually reach the annotation function.
    assignment_hierarchy = annotate_levels(
        adata,
        marker_hierarchy,
        group_name,
        method=method,
        layer=layer,
        **kwargs
    )

    return dict(
        assignments=get_annot_df(assignment_hierarchy, group_name, min_expr=min_expr),
        metrics=assignment_hierarchy,
    )


def annotate_levels(
    adata,
    marker_hierarchy,
    group_name,
    level=0,
    assignment_levels=None,
    method="auroc",
    layer=None,
    **kwargs
):
    """
    Recursively annotate all levels of a marker hierarchy.

    Parameters
    ----------
    adata
        AnnData object (or a subset of it during recursion)
    marker_hierarchy
        Dict mapping each cell type to a dict with ``"marker_genes"`` and,
        optionally, ``"subtypes"`` (a nested hierarchy).
    group_name
        Name of the column in adata.obs that contains the cluster labels
    level
        Depth of the parent call; the current level is ``level + 1``.
    assignment_levels
        Dict of per-level result tables accumulated so far.
    method, layer, **kwargs
        Passed through to ``annotate``.

    Returns
    -------
    dict
        Maps ``"level_<n>"`` to the concatenated assignment table of that level.
    """
    level += 1
    level_name = "level_" + str(level)
    marker_dict = get_markers(marker_hierarchy)
    assignments = annotate(
        adata,
        marker_dict,
        group_name,
        method=method,
        layer=layer,
        level=level,
        **kwargs
    )

    if assignment_levels is None:
        assignment_levels = {}

    # Several branches of the hierarchy can contribute to the same level.
    previous = assignment_levels.get(level_name)
    if previous is None:
        assignment_levels[level_name] = assignments
    else:
        assignment_levels[level_name] = pd.concat([previous, assignments], axis=0)

    for subtype in assignments["class"].unique():
        # Labels such as 'na' are not part of the hierarchy and have no children.
        subtype_entry = marker_hierarchy.get(subtype)
        if subtype_entry is None or "subtypes" not in subtype_entry:
            continue

        # Subset adata to the clusters assigned to this subtype
        is_subtype = assignments["class"] == subtype
        subtype_groups = assignments.loc[is_subtype, group_name].astype(str)
        subtype_adata = adata[adata.obs[group_name].isin(subtype_groups)]

        # Recursively annotate
        assignment_levels = annotate_levels(
            subtype_adata,
            subtype_entry["subtypes"],
            group_name,
            level=level,
            assignment_levels=assignment_levels,
            method=method,
            layer=layer,
            **kwargs
        )

    return assignment_levels


def annotate(adata, marker_dict, group_name, method="auroc", layer=None, level=None, **kwargs):
    """
    Annotate clusters based on manually defined cell type markers.

    Parameters
    ----------
    adata
        AnnData object
    marker_dict
        Dict with marker genes for each celltype
    group_name
        Name of the column in adata.obs that contains the cluster labels
    method
        Method to use for annotation. Options are "auroc", "trinatize"
        (also accepted as "trinarize") and "degenes".
    layer
        Layer in adata to use for expression
    level
        Hierarchy level of this call; only used by the "degenes" method.
    **kwargs
        Additional arguments to pass to the annotation function.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported options.
    """
    if method == "auroc":
        assignments = annotate_snap(
            adata, marker_dict, group_name, layer=layer, **kwargs
        )
    elif method in ("trinatize", "trinarize"):
        assignments = annotate_cytograph(
            adata, marker_dict, group_name, layer=layer, **kwargs
        )
    elif method == "degenes":
        # Forward the real level; it was previously always replaced by None.
        assignments = annotate_degenes(
            adata, marker_dict, group_name, layer=layer, level=level, **kwargs
        )
    else:
        raise ValueError(
            f"Unknown annotation method: {method!r}. "
            'Expected "auroc", "trinatize" or "degenes".'
        )

    # Turn the cluster index into a regular column named after group_name
    assignments = assignments.reset_index(names=group_name)
    return assignments
def auc_expr(adata, group_name, features=None, layer=None):
    """Computes AUROC and fraction nonzero for each gene in an adata object.

    Returns a dict with ``frac_nonzero`` and ``auroc`` (one row per group),
    the ``features`` actually used and the ``groups`` (encoder classes).
    """
    # Turn string groups into integers
    le = preprocessing.LabelEncoder()
    groups = jnp.array(le.fit_transform(adata.obs[group_name]))

    # Compute AUROC and fraction nonzero
    expr, features = get_expr(adata, features=features, layer=layer)
    auroc, frac_nz = expr_auroc_over_groups(expr, groups)

    return dict(
        frac_nonzero=frac_nz,
        auroc=auroc,
        features=features,
        groups=le.classes_,
    )


@jax.jit
@partial(jax.vmap, in_axes=(1, None))
def jit_auroc(x, groups):
    """Area under the ROC curve of scores ``x`` against boolean ``groups``.

    Mapped over the columns of ``x`` (axis 1); ``groups`` is shared.
    """
    # TODO: compute frac nonzero here to avoid iterating twice

    # sort scores and corresponding truth values
    desc_score_indices = jnp.argsort(x)[::-1]
    x = x[desc_score_indices]
    groups = groups[desc_score_indices]

    # x typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = jnp.array(jnp.diff(x) != 0, dtype=jnp.int32)
    threshold_mask = jnp.r_[distinct_value_indices, 1]

    # accumulate the true positives with decreasing threshold
    tps_ = jnp.cumsum(groups)
    fps_ = 1 + jnp.arange(groups.size) - tps_

    # mask out the values that are not distinct
    tps = jnp.sort(tps_ * threshold_mask)
    fps = jnp.sort(fps_ * threshold_mask)
    tps = jnp.r_[0, tps]
    fps = jnp.r_[0, fps]
    fpr = fps / fps[-1]
    tpr = tps / tps[-1]

    # Trapezoidal rule written out explicitly, so the code does not depend
    # on jnp.trapz, which is deprecated in JAX.
    area = jnp.sum(jnp.diff(fpr) * (tpr[1:] + tpr[:-1]) * 0.5)
    return area


def expr_auroc_over_groups(expr, groups):
    """Computes AUROC and fraction nonzero for each group separately.

    ``groups`` holds one integer code per row of ``expr``; the result has
    one row per code in ``0..groups.max()`` and one column per feature.
    """
    n_groups = int(groups.max()) + 1

    auroc_rows = []
    frac_nz_rows = []
    for group in range(n_groups):
        in_group = groups == group
        auroc_rows.append(jit_auroc(expr, in_group))
        # Fraction of cells in the group with expression > 0 per feature
        frac_nz_rows.append(jnp.mean(expr[in_group, :] > 0, axis=0))

    return jnp.stack(auroc_rows), jnp.stack(frac_nz_rows)
+ + Parameters + ---------- + adata + AnnData object + marker_dict + Dict with marker genes for each celltype + group_name + Name of the column in adata.obs that contains the cluster labels + layer + Layer in adata to use for expression + """ + # level_name = "level_" + str(level) + + # TODO magic way for adata only has one cluster + # fixed + # if len(adata.obs[group_name].unique()) <= 1: + # # 1st way + # # assign_df = annotate_snap( + # # adata, marker_dict, group_name, layer=layer + # # ) + # # 2nd way + # assign_df = pd.DataFrame({'class':['na'], 'score':[np.nan], 'expr':[1]}) + # assign_df.index=adata.obs[group_name].unique() + # # 3rd way + # # assign_df=pd.DataFrame() + # return assign_df + + + # for celltype in marker_dict['subtypes']: + # subgenes = marker_dict['subtypes'][celltype] + # if marker_dict['subtypes'][celltype]['marker_genes'] == []: + # markers_all=[] + # for i in marker_dict['subtypes']: + # for j in marker_dict['subtypes'][i]['marker_genes']: + # markers_all.append(j) + # marker_dict['subtypes'][celltype]['marker_genes'] = markers_all + + # cal max de + corr_df = get_bulk_exp(adata, group_name).astype(float).corr() + corr_df = 1 - corr_df + + if level==1: + dist_pect=10 + else: + dist_pect=2 + + ntop = math.ceil(len(adata.obs[group_name].unique())/dist_pect) + cluster_to_compair = corr_df.apply(lambda s: s.abs().nlargest(ntop).index.tolist(), axis=1).to_dict() + + result_df_zscore = pd.DataFrame(marker_dict.keys()) + result_df_zscore = result_df_zscore.rename(columns={0:'level_name'}) + + result_df_apvalue = pd.DataFrame(marker_dict.keys()) + result_df_apvalue = result_df_apvalue.rename(columns={0:'level_name'}) + + if len(adata.obs[group_name].unique()) == 1: + assign_df = pd.DataFrame(index=adata.obs[group_name].unique()) + assign_df['max_de'] = 'na' + assign_df['de_score'] = 0 + z_df = pd.DataFrame(columns=adata.obs[group_name].unique(), + index=list(marker_dict.keys())).fillna(0) + + else: + for cluster in 
adata.obs[group_name].unique(): + adata0 = adata.copy() + adata0.obs[group_name] = adata0.obs[group_name].astype(str) + + # sc.tl.rank_genes_groups(adata0, group_name, groups=[cluster], + # reference=cluster_to_compair[cluster], method='wilcoxon') + + adata0.obs.loc[adata0.obs[group_name].isin(cluster_to_compair[cluster]), + group_name] = 'ref' + + sc.tl.rank_genes_groups(adata0, group_name, groups=[cluster], + reference='ref', method='wilcoxon') + + wranks = wrangle_ranks_from_adata(adata0) + + z_scores=[] + adj_pvalss=[] + for i in marker_dict: + z_scores.append(wranks.loc[wranks.gene.isin(marker_dict[i]), 'z_score'].max()) + adj_pvalss.append(-np.log10(wranks.loc[wranks.gene.isin(marker_dict[i]), 'adj_pvals']).max()) + # z_scores.append(wranks.loc[wranks.gene.isin(marker_dict[i]['marker_genes']), 'z_score'].max()) + # adj_pvalss.append(-np.log10(wranks.loc[wranks.gene.isin(marker_dict[i]['marker_genes']), 'adj_pvals']).max()) + + # result_df_zscore[cluster]=[i*j for i,j in zip(z_scores,adj_pvalss)] + result_df_zscore[cluster]=z_scores + result_df_apvalue[cluster]=adj_pvalss + + z_df = result_df_zscore.set_index('level_name') + cluster2ct = z_df.idxmax().to_dict() + + has_na = False + # TODO add has_na + if has_na: + for i in cluster2ct: + if i in z_df[z_df.max(axis=1)<1].index: + cluster2ct[i] = 'na' + + assign_df = pd.DataFrame(pd.Series(cluster2ct)) + assign_df = assign_df.rename(columns={0:'max_de'}) + assign_df['de_score'] = z_df.max() + + # cal max exp + raw_bulk = get_bulk_exp(adata, group_name, 'raw') + max_exp=pd.DataFrame(index=raw_bulk.columns) + for cell in marker_dict: + if sum(raw_bulk.index.isin(marker_dict[cell]))>0: + good_markers = [i for i in marker_dict[cell] if i in raw_bulk.index] + max_exp[cell] = raw_bulk.loc[good_markers].max() + else: + max_exp[cell] = 0 + + s = max_exp.select_dtypes(include='object').columns + max_exp[s] = max_exp[s].astype("float") + max_exp_df = pd.DataFrame(max_exp.max(axis=1)) + max_exp_df = 
max_exp_df.rename(columns={0:'exp_score'}) + + max_exp_df['max_exp'] = max_exp.idxmax(axis=1) + + + # merge de and exp + assign_df = pd.merge(assign_df, max_exp_df, left_index=True, right_index=True) + + + + mt_results = z_df * max_exp.T.loc[z_df.index, z_df.columns] + + mt_results_df = pd.DataFrame(mt_results.max(axis=0)) + mt_results_df = mt_results_df.rename(columns={0:'mt_score'}) + + mt_results_df['max_mt'] = mt_results.idxmax(axis=0) + + assign_df = pd.merge(assign_df, mt_results_df, left_index=True, right_index=True) + + use_mt = True + na_cutoff = 0.1 + if use_mt: + classs=[] + for index,row in assign_df.iterrows(): + if row['mt_score'] <= na_cutoff: + classs.append(row['max_exp']) + else: + classs.append(row['max_mt']) + + # else: + # for index,row in assign_df.iterrows(): + # if row['mt_score'] < na_cutoff: + # if row['de_score'] > na_cutoff: + # classs.append(row['max_de']) + # elif row['exp_score'] > na_cutoff: + # classs.append(row['max_exp']) + # else: + # classs.append('na') + # else: + # classs.append(row['max_mt']) + + else: + classs=[] + # for index,row in assign_df.iterrows(): + # if row['de_score'] > 2: + # classs.append(row['max_de']) + # elif row['de_score'] > 1 and row['exp_score'] < 2: + # classs.append(row['max_de']) + # elif row['de_score'] > 1 and row['exp_score'] > 2: + # classs.append(row['max_exp']) + # elif row['exp_score'] > 0: + # classs.append(row['max_exp']) + # else: + # classs.append('na') + for index,row in assign_df.iterrows(): + if row['de_score'] > 2: + classs.append(row['max_de']) + elif row['de_score'] > 1: + if row['exp_score'] > 2: + classs.append(row['max_exp']) + else: + classs.append(row['max_de']) + elif row['de_score'] > 0: + if row['exp_score'] > 0.5: + classs.append(row['max_exp']) + else: + classs.append('na') + else: + classs.append('na') + + assign_df['class'] = classs + + # TODO magic way to avoid get_annot_df error + assign_df['expr'] = 1 + + return assign_df + + # adata.obs[level_name] = 
def wrangle_ranks_from_adata(adata):
    """
    Wrangle results from the rank_genes_groups function of Scanpy.

    Parameters
    ----------
    adata
        Object whose ``uns['rank_genes_groups']`` holds the ``names``,
        ``scores`` and ``pvals_adj`` tables (one field/column per group).

    Returns
    -------
    pandas.DataFrame
        Long table with columns ``z_score``, ``adj_pvals``, ``gene`` and
        ``cluster_number``; the rows of each group are stacked and keep
        their per-group row index.
    """
    ranks = adata.uns['rank_genes_groups']
    # Get number of top ranked genes per groups
    nb_marker = len(ranks['names'])
    # Wrangle results into tables (pandas dataframes)
    top_score = pd.DataFrame(ranks['scores']).loc[:nb_marker]
    top_adjpval = pd.DataFrame(ranks['pvals_adj']).loc[:nb_marker]
    top_gene = pd.DataFrame(ranks['names']).loc[:nb_marker]

    columns = ['z_score', 'adj_pvals', 'gene', 'cluster_number']
    # Collect per-group frames and concatenate once: DataFrame.append was
    # removed in pandas 2.0 and copied the whole table on every iteration.
    frames = [
        pd.DataFrame(
            {
                'z_score': top_score[group],
                'adj_pvals': top_adjpval[group],
                'gene': top_gene[group],
                'cluster_number': group,
            },
            columns=columns,
        )
        for group in top_score.columns
    ]
    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames)
read_yaml +from .utils import read_yaml, dict_to_binary, frac_nonzero, masked_max, match, to_dense + + +def annotate_degenes( + adata, + marker_dict, + group_name, + layer=None, + level=None, + auc_weight=0.5, + expr_weight=0.5 + ): + """ + Annotate cell types based on differentially expressed (DE) marker genes. + + Parameters + ---------- + adata + AnnData object + marker_dict + Dict with marker genes for each celltype + group_name + Name of the column in adata.obs that contains the cluster labels + layer + Layer in adata to use for expression + """ + # level_name = "level_" + str(level) + + # TODO magic way for adata only has one cluster + # fixed + # if len(adata.obs[group_name].unique()) <= 1: + # # 1st way + # # assign_df = annotate_snap( + # # adata, marker_dict, group_name, layer=layer + # # ) + # # 2nd way + # assign_df = pd.DataFrame({'class':['na'], 'score':[np.nan], 'expr':[1]}) + # assign_df.index=adata.obs[group_name].unique() + # # 3rd way + # # assign_df=pd.DataFrame() + # return assign_df + + + # for celltype in marker_dict['subtypes']: + # subgenes = marker_dict['subtypes'][celltype] + # if marker_dict['subtypes'][celltype]['marker_genes'] == []: + # markers_all=[] + # for i in marker_dict['subtypes']: + # for j in marker_dict['subtypes'][i]['marker_genes']: + # markers_all.append(j) + # marker_dict['subtypes'][celltype]['marker_genes'] = markers_all + + # cal max de + corr_df = get_bulk_exp(adata, group_name).astype(float).corr() + corr_df = 1 - corr_df + + if level==1: + dist_pect=10 + else: + dist_pect=2 + + ntop = math.ceil(len(adata.obs[group_name].unique())/dist_pect) + cluster_to_compair = corr_df.apply(lambda s: s.abs().nlargest(ntop).index.tolist(), axis=1).to_dict() + for i in cluster_to_compair: + if i in cluster_to_compair[i]: + cluster_to_compair[i].remove(i) + + # Reformat marker_dict into binary matrix + marker_mat = dict_to_binary(marker_dict) + + # Compute AUROC and fraction nonzero for marker features + features = 
marker_mat.columns + + le = preprocessing.LabelEncoder() + le.fit(adata.obs[group_name]) + cell_groups = jnp.array(le.transform(adata.obs[group_name])) + + groups = le.classes_ + + expr, features = get_expr(adata, features=features, layer=layer) + + aurocs = jnp.zeros((len(groups), len(features))) + frac_nzs = jnp.zeros((len(groups), len(features))) + allgroups=[] + + aurocs = pd.DataFrame(aurocs).T + aurocs.columns=groups + aurocs.index = features + + frac_nzs = pd.DataFrame(frac_nzs).T + frac_nzs.columns=groups + frac_nzs.index = features + + for group in groups: + + group = str(group) + + adata0 = adata.copy() + adata0.obs[group_name] = adata0.obs[group_name].astype(str) + + if group not in cluster_to_compair: + adata0 = adata0[adata0.obs[group_name].isin([group])] + + else: + adata0.obs.loc[adata0.obs[group_name].isin(cluster_to_compair[group]), + group_name] = 'ref' + adata0 = adata0[adata0.obs[group_name].isin(['ref', group])] + + metric = auc_expr(adata0, group_name, features=features) + aurocs[group] = metric['auroc'][metric['groups']==group][0] + frac_nzs[group] = metric['frac_nonzero'][metric['groups']==group][0] + allgroups.append(group) + + aurocs=aurocs.fillna(0.5) ### MAGIC TODO MAY PROBLEM + ### The PDAC Hwang_NatGenet_2022 GSE202051_003 sample + ### Level 2 mesenchyme only have 1 cluster lead error + ### KeyError: "None of [Float64Index([nan], dtype='float64')] are in the [columns]" + + metrics={'frac_nonzero':frac_nzs, + 'auroc':aurocs, + 'features':features, + 'groups':groups} + + + marker_mat = marker_mat.loc[:, metrics["features"]] + + auc_max = pd.DataFrame() + for i in marker_dict: + auc_max[i] = aurocs.loc[[i for i in marker_dict[i] if i in aurocs.index],:].max() + + expr_max = pd.DataFrame() + for i in marker_dict: + expr_max[i] = frac_nzs.loc[[i for i in marker_dict[i] if i in frac_nzs.index],:].max() + + # Combine metrics + assignment_scores = (auc_weight * auc_max + expr_weight * expr_max) / ( + auc_weight + expr_weight + ) + + 
def get_expr(adata, features=None, layer=None):
    """Get expression matrix from adata object

    Parameters
    ----------
    adata
        AnnData object
    features
        Optional iterable of feature names; only those present in the
        selected data are kept, in the order given.
    layer
        ``None`` or ``'raw'`` reads ``adata.raw`` (``None`` falls back to
        ``adata.X`` when no raw data is stored); any other value is used as
        a key into ``adata.layers``.

    Returns
    -------
    tuple
        ``(expr, features)``: a dense jax array and the list of feature
        names matching its columns.

    Raises
    ------
    ValueError
        If ``layer='raw'`` is requested but ``adata.raw`` is not set.
    """
    use_raw = layer is None or layer == 'raw'

    if use_raw:
        if adata.raw is not None:
            adata = adata.raw.to_adata()
        elif layer == 'raw':
            raise ValueError("layer='raw' was requested but adata.raw is not set.")
        # layer is None and there is no raw data: use adata.X as it is.

    var_names = adata.var_names.tolist()

    if features is not None:
        # Position of the first occurrence of every feature name
        position = {}
        for idx, name in enumerate(var_names):
            position.setdefault(name, idx)
        # Intersect with adata features, keeping the caller's order so the
        # column order is reproducible between runs.
        features = [name for name in dict.fromkeys(features) if name in position]
        adata = adata[:, [position[name] for name in features]]
    else:
        features = var_names

    if use_raw:
        expr = jnp.array(to_dense(adata.X))
    else:
        expr = jnp.array(to_dense(adata.layers[layer]))

    return expr, features
@jax.jit
@partial(jax.vmap, in_axes=(1, None))
def jit_auroc(x, groups):
    """Area under the ROC curve of scores ``x`` against boolean ``groups``.

    Mapped over the columns of ``x`` (axis 1); ``groups`` is shared.
    """
    # TODO: compute frac nonzero here to avoid iterating twice

    # sort scores and corresponding truth values
    desc_score_indices = jnp.argsort(x)[::-1]
    x = x[desc_score_indices]
    groups = groups[desc_score_indices]

    # x typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = jnp.array(jnp.diff(x) != 0, dtype=jnp.int32)
    threshold_mask = jnp.r_[distinct_value_indices, 1]

    # accumulate the true positives with decreasing threshold
    tps_ = jnp.cumsum(groups)
    fps_ = 1 + jnp.arange(groups.size) - tps_

    # mask out the values that are not distinct
    tps = jnp.sort(tps_ * threshold_mask)
    fps = jnp.sort(fps_ * threshold_mask)
    tps = jnp.r_[0, tps]
    fps = jnp.r_[0, fps]
    fpr = fps / fps[-1]
    tpr = tps / tps[-1]

    # Trapezoidal rule written out explicitly, so the code does not depend
    # on jnp.trapz, which is deprecated in JAX.
    area = jnp.sum(jnp.diff(fpr) * (tpr[1:] + tpr[:-1]) * 0.5)
    return area


def expr_auroc_over_groups(expr, cell_groups, groups):
    """Computes AUROC and fraction nonzero for each group separately.

    Parameters
    ----------
    expr
        Expression matrix with one row per cell and one column per feature.
    cell_groups
        Integer code of each cell; code ``i`` stands for ``groups[i]``.
    groups
        Group labels in encoder order. Only their number is used, so the
        labels may be arbitrary strings.

    Returns
    -------
    tuple
        ``(auroc, frac_nz)``, each with one row per entry of ``groups``.
    """
    auroc_rows = []
    frac_nz_rows = []

    # Row i belongs to groups[i]. The label itself must not be used as an
    # index: it is a cluster name, not the encoded position.
    for code in range(len(groups)):
        in_group = cell_groups == code
        auroc_rows.append(jit_auroc(expr, in_group))
        # Fraction of cells in the group with expression > 0 per feature
        frac_nz_rows.append(jnp.mean(expr[in_group, :] > 0, axis=0))

    if not auroc_rows:
        empty = jnp.zeros((0, expr.shape[1]))
        return empty, empty

    return jnp.stack(auroc_rows), jnp.stack(frac_nz_rows)
def annotate_cytograph(adata, marker_dict, group_name, layer=None, f=0.2):
    """
    Annotate clusters based on trinarization of marker gene expression.

    Parameters
    ----------
    adata
        AnnData object
    marker_dict
        Dict with marker genes for each celltype
    group_name
        Name of the column in adata.obs that contains the cluster labels
    layer
        Layer in adata to use for expression
    f
        Expression fraction threshold passed on to ``trinarize``.

    Returns
    -------
    pandas.DataFrame
        One row per cluster with the best-scoring ``class`` and its ``score``.
    """
    # Reformat marker_dict into binary matrix
    marker_mat = dict_to_binary(marker_dict)

    # Compute trinaries; honour the caller's `f` instead of a fixed 0.2
    metrics = trinarize(
        adata, group_name, features=marker_mat.columns, layer=layer, f=f
    )

    # trinarize may drop or reorder features, so align the marker matrix
    # columns with the features the trinaries were computed for.
    marker_mat = marker_mat.loc[:, metrics["features"]]

    annot_probs = get_annot_probs(np.array(metrics["trinaries"]), marker_mat.values)

    # Best cell type (row) for every cluster (column)
    best_type = np.argmax(annot_probs, axis=0)
    annot_scores = np.max(annot_probs, axis=0)

    annot_tags = [
        tag if isinstance(tag, str) else ";".join(tag)
        for tag in marker_mat.index[best_type]
    ]

    assign_df = pd.DataFrame(
        {"class": annot_tags, "score": annot_scores},
        index=list(metrics["groups"]),
    )
    return assign_df
@partial(jax.jit, static_argnames=["n_groups", "f"])
@partial(jax.vmap, in_axes=(1, None, None, None))
def betabinomial_trinarize_array(x, groups, n_groups, f):
    """
    Trinarize a vector, grouped by groups, using a beta binomial model

    Parameters
    ----------
    x
        The input expression vector.
    groups
        Group labels.
    n_groups
        Number of groups (static under jit).
    f
        Expression fraction threshold (static under jit).

    Returns
    -------
    ps
        The posterior probability of expression in at least a fraction f
    """
    x = jnp.round(jnp.asarray(x))
    # Cells per group and expressing cells per group; a weighted bincount
    # replaces the Python loop that was unrolled once per group under jit.
    n_by_group = jnp.bincount(groups, length=n_groups)
    expressed = jnp.where(x != 0, 1.0, 0.0)
    k_by_group = jnp.bincount(groups, weights=expressed, length=n_groups)
    ps = p_half(k_by_group, n_by_group, f)
    return ps


@jax.jit
@partial(jax.vmap, in_axes=(0, 0, None))
def p_half(k: int, n: int, f: float) -> float:
    """
    Return probability that at least a fraction f of the cells express, if we have observed k of n cells expressing:

        p|k,n = 1-(betainc(a+k, b-k+n, f)*beta(a+k, b-k+n)*gamma(a+b+n)/(gamma(a+k)*gamma(b-k+n)))

    Parameters
    ----------
    k
        Number of observed positive cells
    n
        Total number of cells
    f
        Fraction threshold at which the incomplete beta function is evaluated
    """
    # These are the prior hyperparameters beta(a,b)
    a = 1.5
    b = 2
    # The products of beta and gamma terms are numerically unstable, so we
    # work on log scale (and special-case the incomplete beta)
    incb = betainc(a + k, b - k + n, f)
    p = 1.0 - jnp.exp(
        jnp.log(incb)
        + betaln(a + k, b - k + n)
        + gammaln(a + b + n)
        - gammaln(a + k)
        - gammaln(b - k + n)
    )
    return p
def match(a, b):
    """Return, for every item of ``a``, its first position in ``b``.

    Items of ``a`` that do not occur in ``b`` yield ``None``.

    Implemented in plain Python with a single lookup table: the result mixes
    integers and ``None``, which a compiled numba function cannot type, and
    the table avoids rescanning ``b`` for every item of ``a``.
    """
    first_position = {}
    for idx, item in enumerate(b):
        # setdefault keeps the first occurrence, like list.index
        first_position.setdefault(item, idx)
    return [first_position.get(item) for item in a]
def get_annot_df(x, group_name, min_expr=0.1):
    """Combine per-level assignment tables into one table of annotations.

    Parameters
    ----------
    x
        Dict mapping a level name to a DataFrame that has a ``group_name``
        column and a ``"class"`` column, and optionally an ``"expr"`` column.
    group_name
        Name of the column holding the cluster labels; becomes the index.
    min_expr
        Annotations whose ``"expr"`` value is not above this threshold are
        dropped. Levels without an ``"expr"`` column are kept unfiltered.

    Returns
    -------
    pandas.DataFrame
        One row per cluster and one column per level (named after the keys
        of ``x``); empty if ``x`` is empty.
    """
    # Get valid annots from each level
    annot_list = []
    for level_df in x.values():
        indexed = level_df.set_index(group_name)
        annot = indexed["class"]
        # Not every annotation method reports an expression score
        if min_expr > 0 and "expr" in indexed.columns:
            annot = annot[indexed["expr"] > min_expr]
        annot_list.append(annot)

    if not annot_list:
        return pd.DataFrame()

    # Concat annots
    annot_df = pd.concat(annot_list, axis=1)
    # Rename cols to levels
    annot_df.columns = [str(level) for level in x]
    return annot_df
+ + +## Quick start + +```python +import snapseed as snap +from snapseed.utils import read_yaml + +# Read in the marker genes +marker_genes = read_yaml("marker_genes.yaml") + +# Annotate anndata objects +snap.annotate( + adata, + marker_genes, + group_name="clusters", + layer="lognorm", +) + +# Or for more complex hierarchies +snap.annotate_hierarchy( + adata, + marker_genes, + group_name="clusters", + layer="lognorm", +) +``` diff --git a/snapseed.egg-info/SOURCES.txt b/snapseed.egg-info/SOURCES.txt new file mode 100644 index 0000000..0434dd2 --- /dev/null +++ b/snapseed.egg-info/SOURCES.txt @@ -0,0 +1,14 @@ +README.md +setup.py +snapseed/__init__.py +snapseed/annotate.py +snapseed/auroc.py +snapseed/degenes copy.py +snapseed/degenes.py +snapseed/trinarize.py +snapseed/utils.py +snapseed.egg-info/PKG-INFO +snapseed.egg-info/SOURCES.txt +snapseed.egg-info/dependency_links.txt +snapseed.egg-info/requires.txt +snapseed.egg-info/top_level.txt \ No newline at end of file diff --git a/snapseed.egg-info/dependency_links.txt b/snapseed.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/snapseed.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/snapseed.egg-info/requires.txt b/snapseed.egg-info/requires.txt new file mode 100644 index 0000000..412f00d --- /dev/null +++ b/snapseed.egg-info/requires.txt @@ -0,0 +1,7 @@ +jax>=0.2.26 +jaxlib>=0.1.75 +numpy>=1.18.5 +numba>=0.56.4 +scanpy>=1.9.1 +scikit-learn>=1.1.3 +pandas>=1.5.2 diff --git a/snapseed.egg-info/top_level.txt b/snapseed.egg-info/top_level.txt new file mode 100644 index 0000000..b7c5403 --- /dev/null +++ b/snapseed.egg-info/top_level.txt @@ -0,0 +1 @@ +snapseed diff --git a/snapseed/annotate.py b/snapseed/annotate.py index 96dc0e5..2c2a1ab 100644 --- a/snapseed/annotate.py +++ b/snapseed/annotate.py @@ -2,6 +2,7 @@ from .trinarize import annotate_cytograph from .auroc import annotate_snap +from .degenes import annotate_degenes from .utils import get_markers, 
def annotate_degenes(
    adata,
    marker_dict,
    group_name,
    layer=None,
    level=None,
    na_cutoff=0.1,
):
    """
    Annotate cell types based on differentially expressed (DE) marker genes.

    For every cluster a Wilcoxon ``rank_genes_groups`` test is run against a
    reference made of the clusters whose pseudo-bulk profiles are the least
    correlated with it. The best marker z-score per cell type is multiplied by
    the best raw pseudo-bulk marker expression; the cell type with the highest
    product wins, unless that product is <= ``na_cutoff``, in which case the
    cell type with the highest expression alone is used.

    Parameters
    ----------
    adata
        AnnData object; ``adata.raw`` is read for the expression scores.
    marker_dict
        Dict with marker genes for each celltype
    group_name
        Name of the column in adata.obs that contains the cluster labels
    layer
        Accepted for interface compatibility; not used by this function.
    level
        Hierarchy level. At level 1 the reference uses the top 1/10 of the
        clusters, otherwise the top 1/2 (rounded up).
    na_cutoff
        Threshold on the combined score below which the expression-only call
        is used.

    Returns
    -------
    pandas.DataFrame
        Indexed by cluster with the columns ``max_de``, ``de_score``,
        ``exp_score``, ``max_exp``, ``mt_score``, ``max_mt``, ``class`` and
        ``expr``.
    """
    clusters = adata.obs[group_name].unique()
    cell_types = list(marker_dict.keys())

    if len(clusters) == 1:
        # Nothing to compare against: no DE signal for a single cluster.
        assign_df = pd.DataFrame(index=clusters)
        assign_df["max_de"] = "na"
        assign_df["de_score"] = 0
        z_df = pd.DataFrame(0, index=cell_types, columns=clusters)
    else:
        # Distance between clusters = 1 - correlation of pseudo-bulk profiles.
        dist = 1 - get_bulk_exp(adata, group_name).astype(float).corr()
        dist_pect = 10 if level == 1 else 2
        ntop = math.ceil(len(clusters) / dist_pect)

        labels = adata.obs[group_name].astype(str)
        # One working copy is enough: only the label column changes per cluster
        # and rank_genes_groups overwrites its own entry in ``uns``.
        work = adata.copy()

        z_scores = {}
        for cluster in clusters:
            far = dist.loc[cluster].abs().nlargest(ntop).index
            # A cluster must never be part of its own reference, otherwise it
            # is relabelled 'ref' and cannot be tested any more.
            refs = [str(c) for c in far if c != cluster]
            if not refs:
                refs = [str(c) for c in dist.index if c != cluster]

            work.obs[group_name] = labels.where(~labels.isin(refs), "ref")
            sc.tl.rank_genes_groups(
                work,
                group_name,
                groups=[str(cluster)],
                reference="ref",
                method="wilcoxon",
            )
            wranks = wrangle_ranks_from_adata(work)

            z_scores[cluster] = [
                wranks.loc[wranks.gene.isin(marker_dict[ct]), "z_score"].max()
                for ct in cell_types
            ]

        z_df = pd.DataFrame(z_scores, index=pd.Index(cell_types, name="level_name"))

        # TODO: optional 'na' call for clusters without any convincing z-score.
        assign_df = z_df.idxmax().to_frame("max_de")
        assign_df["de_score"] = z_df.max()

    # Highest raw pseudo-bulk expression among the markers of each cell type.
    raw_bulk = get_bulk_exp(adata, group_name, "raw")
    max_exp = pd.DataFrame(index=raw_bulk.columns)
    for cell, genes in marker_dict.items():
        present = [g for g in genes if g in raw_bulk.index]
        max_exp[cell] = raw_bulk.loc[present].max() if present else 0

    object_cols = max_exp.select_dtypes(include="object").columns
    max_exp[object_cols] = max_exp[object_cols].astype("float")

    max_exp_df = max_exp.max(axis=1).to_frame("exp_score")
    max_exp_df["max_exp"] = max_exp.idxmax(axis=1)

    # Merge DE and expression evidence.
    assign_df = pd.merge(assign_df, max_exp_df, left_index=True, right_index=True)

    # Combined score: z-score times expression, per (cell type, cluster).
    mt_results = z_df * max_exp.T.loc[z_df.index, z_df.columns]
    mt_results_df = mt_results.max(axis=0).to_frame("mt_score")
    mt_results_df["max_mt"] = mt_results.idxmax(axis=0)

    assign_df = pd.merge(assign_df, mt_results_df, left_index=True, right_index=True)

    # A NaN combined score compares False and therefore keeps ``max_mt``.
    assign_df["class"] = np.where(
        assign_df["mt_score"] <= na_cutoff,
        assign_df["max_exp"],
        assign_df["max_mt"],
    )

    # Constant placeholder so downstream filtering on 'expr' keeps every row.
    assign_df["expr"] = 1

    return assign_df
def wrangle_ranks_from_adata(adata):
    """
    Wrangle results from the ranked_genes_groups function of Scanpy.

    Reads ``adata.uns['rank_genes_groups']`` (fields ``scores``, ``pvals_adj``
    and ``names``) and stacks the per-group results into one long table.

    Returns
    -------
    pandas.DataFrame
        Columns ``z_score``, ``adj_pvals``, ``gene`` and ``cluster_number``;
        the row index restarts for every group.
    """
    ranks = adata.uns["rank_genes_groups"]
    scores = pd.DataFrame(ranks["scores"])
    adj_pvals = pd.DataFrame(ranks["pvals_adj"])
    genes = pd.DataFrame(ranks["names"])

    columns = ["z_score", "adj_pvals", "gene", "cluster_number"]
    frames = [
        pd.DataFrame(
            {
                "z_score": scores[group],
                "adj_pvals": adj_pvals[group],
                "gene": genes[group],
                "cluster_number": group,
            },
            columns=columns,
        )
        for group in scores.columns
    ]
    if not frames:
        return pd.DataFrame(columns=columns)

    # DataFrame.append was removed in pandas 2.0; concatenate once instead.
    return pd.concat(frames)


def get_bulk_exp(adata, bulk_labels, layer='var'):
    """Average expression per cluster (pseudo-bulk).

    Parameters
    ----------
    adata
        AnnData-like object. ``adata.obs[bulk_labels]`` must be categorical.
    bulk_labels
        Name of the column in ``adata.obs`` holding the cluster labels.
    layer
        ``'raw'`` averages ``adata.raw.X``; any other value averages ``adata.X``.

    Returns
    -------
    pandas.DataFrame
        Float table with genes as rows and cluster categories as columns.
    """
    use_raw = layer == "raw"
    var_names = adata.raw.var_names if use_raw else adata.var_names
    labels = adata.obs[bulk_labels]
    categories = labels.cat.categories

    means = {}
    for clust in categories:
        subset = adata[labels.isin([clust]), :]
        matrix = subset.raw.X if use_raw else subset.X
        # Sparse matrices return a 1 x n np.matrix here; flatten to 1-D.
        means[clust] = np.asarray(matrix.mean(0), dtype=float).ravel()

    return pd.DataFrame(means, index=var_names, columns=categories, dtype=float)
def annotate_degenes(
    adata,
    marker_dict,
    group_name,
    layer=None,
    level=None,
    auc_weight=0.5,
    expr_weight=0.5,
):
    """
    Annotate cell types based on differentially expressed (DE) marker genes.

    Each cluster is compared against a reference made of the clusters whose
    pseudo-bulk profiles are the least correlated with it. For every marker
    gene the AUROC (cluster vs. reference) and the fraction of expressing
    cells in the cluster are computed; per cell type the best marker is kept
    and both metrics are combined into a weighted score.

    Parameters
    ----------
    adata
        AnnData object
    marker_dict
        Dict with marker genes for each celltype
    group_name
        Name of the column in adata.obs that contains the cluster labels
    layer
        Layer in adata to use for expression (``None`` or ``'raw'`` read
        ``adata.raw``).
    level
        Hierarchy level. At level 1 the reference uses the top 1/10 of the
        clusters, otherwise the top 1/2 (rounded up).
    auc_weight
        Weight of the AUROC in the combined score.
    expr_weight
        Weight of the expressing fraction in the combined score.

    Returns
    -------
    pandas.DataFrame
        Indexed by cluster label (as string) with the columns ``class``,
        ``score``, ``auc`` and ``expr``.
    """
    # Work on string labels throughout so dict keys, encoder classes and the
    # relabelled column are always comparable.
    labels = adata.obs[group_name].astype(str)

    # Distance between clusters = 1 - correlation of pseudo-bulk profiles.
    dist = 1 - get_bulk_exp(adata, group_name).astype(float).corr()
    dist_pect = 10 if level == 1 else 2
    ntop = math.ceil(len(adata.obs[group_name].unique()) / dist_pect)
    # Reference clusters per cluster; a cluster is never its own reference.
    references = {
        str(cluster): [
            str(other)
            for other in dist.loc[cluster].abs().nlargest(ntop).index
            if other != cluster
        ]
        for cluster in dist.index
    }

    # Marker features that are actually present in the expression matrix.
    marker_mat = dict_to_binary(marker_dict)
    _, features = get_expr(adata, features=marker_mat.columns, layer=layer)

    groups = preprocessing.LabelEncoder().fit(labels).classes_
    aurocs = pd.DataFrame(0.0, index=features, columns=groups)
    frac_nzs = pd.DataFrame(0.0, index=features, columns=groups)

    for group in groups:
        relabeled = labels.where(~labels.isin(references.get(group, [])), "ref")
        keep = relabeled.isin(["ref", group]).to_numpy()

        # Copy only the cells that take part in this comparison.
        subset = adata[keep].copy()
        subset.obs[group_name] = relabeled.to_numpy()[keep]

        metric = auc_expr(subset, group_name, features=features, layer=layer)
        row = int(np.flatnonzero(metric["groups"] == group)[0])
        # Align by feature name instead of relying on positional order.
        aurocs[group] = pd.Series(
            np.asarray(metric["auroc"][row]), index=metric["features"]
        )
        frac_nzs[group] = pd.Series(
            np.asarray(metric["frac_nonzero"][row]), index=metric["features"]
        )

    # AUROC is undefined (NaN) when a cluster has no reference cells, e.g. a
    # level with a single cluster; treat that as "no information".
    aurocs = aurocs.fillna(0.5)

    auc_max = _max_over_markers(aurocs, marker_dict)
    expr_max = _max_over_markers(frac_nzs, marker_dict)

    # Combine metrics
    assignment_scores = (auc_weight * auc_max + expr_weight * expr_max) / (
        auc_weight + expr_weight
    )
    assign_class = assignment_scores.idxmax(1)

    # df[assign_class] is a (cluster x cluster) table; its diagonal holds the
    # value of each cluster for its own assigned class.
    return pd.DataFrame(
        {
            "class": assign_class,
            "score": np.diag(assignment_scores[assign_class]),
            "auc": np.diag(auc_max[assign_class]),
            "expr": np.diag(expr_max[assign_class]),
        },
        index=groups,
    )


def _max_over_markers(metric_df, marker_dict):
    """Best (maximum) marker value per cell type for every cluster column."""
    return pd.DataFrame(
        {
            celltype: metric_df.loc[
                [gene for gene in genes if gene in metric_df.index], :
            ].max()
            for celltype, genes in marker_dict.items()
        }
    )


def get_bulk_exp(adata, bulk_labels, layer='var'):
    """Average expression per cluster (pseudo-bulk).

    Parameters
    ----------
    adata
        AnnData-like object. ``adata.obs[bulk_labels]`` must be categorical.
    bulk_labels
        Name of the column in ``adata.obs`` holding the cluster labels.
    layer
        ``'raw'`` averages ``adata.raw.X``; any other value averages ``adata.X``.

    Returns
    -------
    pandas.DataFrame
        Float table with genes as rows and cluster categories as columns.
    """
    use_raw = layer == "raw"
    var_names = adata.raw.var_names if use_raw else adata.var_names
    labels = adata.obs[bulk_labels]
    categories = labels.cat.categories

    means = {}
    for clust in categories:
        subset = adata[labels.isin([clust]), :]
        matrix = subset.raw.X if use_raw else subset.X
        # Sparse matrices return a 1 x n np.matrix here; flatten to 1-D.
        means[clust] = np.asarray(matrix.mean(0), dtype=float).ravel()

    return pd.DataFrame(means, index=var_names, columns=categories, dtype=float)


def get_expr(adata, features=None, layer=None):
    """Get expression matrix from adata object.

    Parameters
    ----------
    adata
        AnnData object.
    features
        Optional feature names; only those present in the data are kept, in
        their given order and without duplicates.
    layer
        ``None`` or ``'raw'`` read ``adata.raw`` (``None`` falls back to
        ``adata.X`` when no raw data is stored); any other value reads
        ``adata.layers[layer]``.

    Returns
    -------
    tuple
        ``(expr, features)`` with ``expr`` a dense jax array and ``features``
        the kept feature names (``None`` when no features were requested).

    Raises
    ------
    ValueError
        If ``layer='raw'`` is requested but ``adata.raw`` is not set.
    """
    use_x = layer is None or layer == "raw"
    if use_x:
        if adata.raw is not None:
            adata = adata.raw.to_adata()
        elif layer == "raw":
            raise ValueError("layer='raw' was requested but adata.raw is not set.")

    if features is not None:
        # Order-preserving intersection keeps the feature order deterministic.
        available = set(adata.var_names)
        features = [f for f in dict.fromkeys(features) if f in available]
        adata = adata[:, match(features, adata.var_names.tolist())]

    matrix = adata.X if use_x else adata.layers[layer]
    return jnp.array(to_dense(matrix)), features


def auc_expr(adata, group_name, features=None, layer=None):
    """Computes AUROC and fraction nonzero for each gene in an adata object.

    Returns a dict with ``frac_nonzero`` and ``auroc`` (one row per group, one
    column per feature), the kept ``features`` and the ``groups`` in row order.
    """
    # Turn string groups into integers
    le = preprocessing.LabelEncoder()
    cell_groups = jnp.array(le.fit_transform(adata.obs[group_name]))
    groups = le.classes_

    expr, features = get_expr(adata, features=features, layer=layer)
    auroc, frac_nz = expr_auroc_over_groups(expr, cell_groups, groups)

    return dict(
        frac_nonzero=frac_nz,
        auroc=auroc,
        features=features,
        groups=groups,
    )


@jax.jit
@partial(jax.vmap, in_axes=(1, None))
def jit_auroc(x, groups):
    """AUROC of every column of ``x`` for the binary membership ``groups``.

    The result is NaN when ``groups`` is all True or all False.
    """
    # sort scores and corresponding truth values
    desc_score_indices = jnp.argsort(x)[::-1]
    x = x[desc_score_indices]
    groups = groups[desc_score_indices]

    # x typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = jnp.array(jnp.diff(x) != 0, dtype=jnp.int32)
    threshold_mask = jnp.concatenate(
        [distinct_value_indices, jnp.ones(1, dtype=jnp.int32)]
    )

    # accumulate the true positives with decreasing threshold
    tps_ = jnp.cumsum(groups)
    fps_ = 1 + jnp.arange(groups.size) - tps_

    # mask out the values that are not distinct
    tps = jnp.sort(tps_ * threshold_mask)
    fps = jnp.sort(fps_ * threshold_mask)
    tps = jnp.concatenate([jnp.zeros(1, dtype=tps.dtype), tps])
    fps = jnp.concatenate([jnp.zeros(1, dtype=fps.dtype), fps])
    fpr = fps / fps[-1]
    tpr = tps / tps[-1]

    # Explicit trapezoidal rule; jnp.trapz is deprecated in recent JAX.
    return jnp.sum(jnp.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0)


def expr_auroc_over_groups(expr, cell_groups, groups):
    """Computes AUROC for each group separately.

    Parameters
    ----------
    expr
        Expression matrix, cells x features.
    cell_groups
        Integer-encoded group of every cell; code ``i`` refers to ``groups[i]``.
    groups
        Group names in encoding order. Only their positions are used, so names
        do not have to be numeric.

    Returns
    -------
    tuple
        ``(auroc, frac_nonzero)``, both with one row per group and one column
        per feature.
    """
    auroc = jnp.zeros((len(groups), expr.shape[1]))
    frac_nz = jnp.zeros((len(groups), expr.shape[1]))

    # Row i must describe groups[i], because callers look rows up by position
    # of the group name.
    for idx in range(len(groups)):
        members = cell_groups == idx
        auroc = auroc.at[idx, :].set(jit_auroc(expr, members))
        frac_nz = frac_nz.at[idx, :].set(jnp.mean(expr[members, :] > 0, axis=0))

    return auroc, frac_nz
def to_jax_array(x):
    """Turn matrix to jax array.

    Objects exposing ``toarray`` or ``todense`` (e.g. scipy sparse matrices)
    are densified first; anything else is passed to ``jnp.asarray`` directly.
    """
    # Densifying is currently the only supported path; very large sparse
    # matrices are therefore not supported.
    if hasattr(x, "toarray"):
        # Prefer toarray(): it yields a plain ndarray rather than np.matrix.
        return jnp.asarray(x.toarray())
    if hasattr(x, "todense"):
        return jnp.asarray(x.todense())
    return jnp.asarray(x)


@jax.jit
@partial(jax.vmap, in_axes=(None, 0))
def masked_mean(x, mask):
    """Mean of ``x`` along axis 1 over the entries selected by each mask.

    ``mask`` is mapped over its first axis, so every row of ``mask`` yields one
    row of the result: ``sum(x * mask_row, axis=1) / sum(mask_row)``.
    An all-zero mask row divides by zero.
    """
    return jnp.sum(x * mask, axis=1) / jnp.sum(mask)


def frac_nonzero(x, axis=0):
    """Fraction of entries of ``x`` that are > 0 along ``axis``."""
    return jnp.mean(x > 0, axis=axis)


# ``axis`` must be static: a traced axis cannot be used by jnp.mean.
jit_frac_nonzero = jax.jit(frac_nonzero, static_argnames="axis")


def matrix_to_long_df(x, features, groups):
    """Converts a (groups x features) matrix to a long dataframe.

    Returns a DataFrame with the columns ``group``, ``feature`` and ``value``.
    """
    df = pd.DataFrame(x, index=groups, columns=features)
    df = df.stack().reset_index()
    df.columns = ["group", "feature", "value"]
    return df