From 0cf3501e1dd000a610aa0174d60e6637f3bb89eb Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 29 Sep 2025 13:55:03 +0200 Subject: [PATCH 01/14] update(): remove redundant code, fix some bugs and add tests --- src/mudata/_core/mudata.py | 89 +++---------- tests/test_update.py | 264 ++++++++++++++++++++++++++++++------- 2 files changed, 232 insertions(+), 121 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index f5b628a..f231af5 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -682,36 +682,16 @@ def _update_attr( data_mod[colname] = col.astype(np.uint32) if len(data_global.columns) > 0: - # TODO: if there were intersecting attrnames between modalities, - # this will increase the size of the index - # Should we use attrmap to figure the index out? - # - if not attr_intersecting: + if not attr_intersecting or axis == (1 - self.axis) or self.axis == -1: data_mod = data_mod.join(data_global, how="left", sort=False) else: # In order to preserve the order of the index, instead, - # perform a join based on (index, attrmap value) pairs. - (col_index,) = self._find_unique_colnames(attr, 1) - data_mod = data_mod.rename_axis(col_index, axis=0).reset_index() - - data_global = data_global.rename_axis(col_index, axis=0).reset_index() - for mod in self.mod.keys(): - # Only add mapping for modalities that exist in attrmap - if mod in getattr(self, attr + "map"): - data_global[mod + ":" + rowcol] = getattr(self, attr + "map")[mod] - attrmap_columns = [ - mod + ":" + rowcol - for mod in self.mod.keys() - if mod in getattr(self, attr + "map") - ] - - data_mod = data_mod.merge( - data_global, on=[col_index, *attrmap_columns], how="left", sort=False - ) - - # Restore the index and remove the helper column - data_mod = data_mod.set_index(col_index).rename_axis(None, axis=0) - data_global = data_global.set_index(col_index).rename_axis(None, axis=0) + # perform a join based on unique index. + data_mod = _make_index_unique(data_mod) + data_global = _make_index_unique(data_global) + data_mod = data_global.join(data_mod, how="left", sort=False) + data_mod = _restore_index(data_mod) + data_global = _restore_index(data_global) # # General case: with duplicates and/or intersections # @@ -782,59 +762,24 @@ def _update_attr( data_mod.reset_index(level=list(range(1, data_mod.index.nlevels)), inplace=True) data_mod.index.set_names(None, inplace=True) + if data_global.shape[0] > 0: + # reorder new index to conform to the old index as much as possible + kept_idx = data_global.index.isin(data_mod.index) + data_mod = data_mod.loc[ + data_global.index[kept_idx].append(data_global.index[~kept_idx]), : + ] + # get adata positions and remove columns from the data frame mdict = {} for m in self.mod.keys(): colname = m + ":" + rowcol mdict[m] = data_mod[colname].to_numpy() - # data_mod.drop(colname, axis=1, inplace=True) - - # Add data from global .obs/.var columns # This might reduce the size of .obs/.var if observations/variables were removed - if getattr(self, attr).index.is_unique: - # There are no new values in the index - # Original index is present in data_global - attr_reindexed = getattr(self, attr).reindex(index=data_mod.index, copy=False) - else: - # Reindexing won't work with duplicated labels: - # cannot reindex on an axis with duplicate labels. - # Use attrmap to resolve it. - - # TODO: might be possible to refactor to memoize it - # if it has already been done in the same ._update_attr() - col_index, col_range = self._find_unique_colnames(attr, 2) - - # copy is made here - data_mod = data_mod.rename_axis(col_index, axis=0).reset_index() - - data_global[col_range] = np.arange(len(data_global)) - - for mod in self.mod.keys(): - if mod in getattr(self, attr + "map"): - data_global[mod + ":" + rowcol] = getattr(self, attr + "map")[mod] - attrmap_columns = [ - mod + ":" + rowcol for mod in self.mod.keys() if mod in getattr(self, attr + "map") - ] - - data_mod = data_mod.merge(data_global, on=attrmap_columns, how="left", sort=False) - - index_selection = data_mod[col_range].values - - data_mod.drop(col_range, axis=1, inplace=True) - data_global.drop(col_range, axis=1, inplace=True) - - # Restore the index and remove the helper column - data_mod = data_mod.set_index(col_index).rename_axis(None, axis=0) - attr_reindexed = getattr(self, attr).iloc[index_selection] - attr_reindexed.index = data_mod.index - - # Clean up - for colname in (mod + "+" + rowcol for mod in self.mod.keys()): - data_mod.drop(colname, axis=1, inplace=True, errors="ignore") + data_mod.drop(colname, axis=1, inplace=True) setattr( self, "_" + attr, - attr_reindexed, + data_mod, ) # Update .obsm/.varm @@ -844,7 +789,7 @@ def _update_attr( for mod, mapping in mdict.items(): attrm[mod] = mapping > 0 - now_index = getattr(self, attr).index + now_index = data_mod.index if len(prev_index) == 0: # New object diff --git a/tests/test_update.py b/tests/test_update.py index ea27d9a..1823a3c 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -1,24 +1,24 @@ import unittest +from functools import reduce import numpy as np import pytest from anndata import AnnData -from mudata import MuData +from mudata import MuData, set_options @pytest.fixture() -def mdata(request, obs_n, obs_across, obs_mod): - # Generate unique, intersecting, and joint observations by default +def modalities(request, obs_n, obs_across, obs_mod): + n_mod = 3 + mods = dict() np.random.seed(100) - mod1 = AnnData(X=np.random.normal(size=3000).reshape(-1, 10)) - mod2 = AnnData(X=np.random.normal(size=1000).reshape(-1, 10)) - - mods = {"mod1": mod1, "mod2": mod2} - # Make var_names different in different modalities - for m in ["mod1", "mod2"]: - mods[m].obs_names = [f"obs{i}" for i in range(mods[m].n_obs)] - mods[m].var_names = [f"{m}_var{i}" for i in range(mods[m].n_vars)] + for i in range(n_mod): + i1 = i + 1 + m = f"mod{i1}" + mods[m] = AnnData(X=np.random.normal(size=3000 * i1).reshape(-1, 10 * i1)) + mods[m].obs["mod"] = m + mods[m].var["mod"] = m mods[m].obs["min_count"] = mods[m].X.min(axis=1) if obs_n: @@ -36,17 +36,20 @@ def mdata(request, obs_n, obs_across, obs_mod): if obs_mod == "duplicated": for m in ["mod1", "mod2"]: # Index does not support mutable operations - obs_names = mods[m].obs_names.to_numpy() + obs_names = mods[m].obs_names.values.copy() obs_names[1] = obs_names[0] mods[m].obs_names = obs_names - elif ( - obs_mod == "extreme_duplicated" - ): # integer overflow: https://github.com/scverse/mudata/issues/107 + elif obs_mod == "extreme_duplicated": # integer overflow: https://github.com/scverse/mudata/issues/107 obs_names = mods["mod1"].obs_names.to_numpy() obs_names[:-1] = obs_names[0] mods["mod1"].obs_names = obs_names - mdata = MuData(mods) + return mods + + +@pytest.fixture() +def mdata(modalities): + mdata = MuData(modalities) batches = np.random.choice(["a", "b", "c"], size=mdata.shape[0], replace=True) mdata.obs["batch"] = batches @@ -54,47 +57,209 @@ def mdata(request, obs_n, obs_across, obs_mod): return mdata -@pytest.fixture() -def modalities(request, obs_n, obs_across, obs_mod): - n_mod = 3 - mods = dict() - np.random.seed(100) - for i in range(n_mod): - i1 = i + 1 - m = f"mod{i1}" - mods[m] = AnnData(X=np.random.normal(size=3000 * i1).reshape(-1, 10 * i1)) - mods[m].obs["mod"] = m - mods[m].var["mod"] = m +@pytest.mark.usefixtures("filepath_h5mu") +@pytest.mark.parametrize("axis", [0, 1]) +class TestMuData: + @pytest.fixture(autouse=True) + def new_update(self): + set_options(pull_on_update=False) + yield + set_options(pull_on_update=None) - if obs_n: - if obs_n == "disjoint": - mod2_which_obs = np.random.choice( - mods["mod2"].obs_names, size=mods["mod2"].n_obs // 2, replace=False + @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_across", ["intersecting"]) + @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) + def test_update_simple(self, modalities, axis): + """ + Update should work when + - obs_names are the same across modalities, + - var_names are unique to each modality + """ + attr = "obs" if axis == 0 else "var" + oattr = "var" if axis == 0 else "obs" + for m, mod in modalities.items(): + setattr(mod, f"{oattr}_names", [f"{m}_{oattr}{j}" for j in range(mod.shape[1 - axis])]) + + mdata = MuData(modalities, axis=axis) + + # names along non-axis are concatenated + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in modalities.values()) + assert ( + getattr(mdata, f"{oattr}_names") + == reduce( + lambda x, y: x.append(y), + (getattr(mod, f"{oattr}_names") for mod in modalities.values()), ) - mods["mod2"] = mods["mod2"][mod2_which_obs].copy() + ).all() - if obs_across: - if obs_across != "intersecting": - raise NotImplementedError("Tests for non-intersecting obs_names are not implemented") + # names along axis are intersected + axisnames = reduce( + lambda x, y: x.union(y, sort=False), + (getattr(mod, f"{attr}_names") for mod in modalities.values()), + ) + assert mdata.shape[axis] == axisnames.shape[0] + assert (getattr(mdata, f"{attr}_names") == axisnames).all() - if obs_mod: - if obs_mod == "duplicated": - for m in ["mod1", "mod2"]: - # Index does not support mutable operations - obs_names = mods[m].obs_names.values.copy() - obs_names[1] = obs_names[0] - mods[m].obs_names = obs_names - elif obs_mod == "extreme_duplicated": - obs_names = mods["mod1"].obs_names.to_numpy() - obs_names[:-1] = obs_names[0] - mods["mod1"].obs_names = obs_names + # Variables are different across modalities + for m, mod in modalities.items(): + # Columns are intact in individual modalities + assert "mod" in mod.obs.columns + assert all(mod.obs["mod"] == m) + assert "mod" in mod.var.columns + assert all(mod.var["mod"] == m) - return mods + @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_across", ["intersecting"]) + @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) + def test_update_duplicates(self, modalities, axis): + """ + Update should work when + - obs_names are the same across modalities, + - there are duplicated var_names, which are not intersecting + between modalities + """ + attr = "obs" if axis == 0 else "var" + oattr = "var" if axis == 0 else "obs" + for m, mod in modalities.items(): + setattr( + mod, f"{oattr}_names", [f"{m}_{oattr}{j // 2}" for j in range(mod.shape[1 - axis])] + ) + + mdata = MuData(modalities, axis=axis) + + # names along non-axis are concatenated + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in modalities.values()) + assert ( + getattr(mdata, f"{oattr}_names") + == reduce( + lambda x, y: x.append(y), + (getattr(mod, f"{oattr}_names") for mod in modalities.values()), + ) + ).all() + + # names along axis are intersected + axisnames = reduce( + lambda x, y: x.union(y, sort=False), + (getattr(mod, f"{attr}_names") for mod in modalities.values()), + ) + assert mdata.shape[axis] == axisnames.shape[0] + assert (getattr(mdata, f"{attr}_names") == axisnames).all() + + # Variables are different across modalities + for m, mod in modalities.items(): + # Columns are intact in individual modalities + assert "mod" in mod.obs.columns + assert all(mod.obs["mod"] == m) + assert "mod" in mod.var.columns + assert all(mod.var["mod"] == m) + + @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_across", ["intersecting"]) + @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) + def test_update_intersecting(self, modalities, axis): + """ + Update should work when + - obs_names are the same across modalities, + - there are intersecting var_names, + which are unique in each modality + """ + attr = "obs" if axis == 0 else "var" + oattr = "var" if axis == 0 else "obs" + for m, mod in modalities.items(): + setattr( + mod, + f"{oattr}_names", + [ + f"{m}_{oattr}{j}" if j != 0 else f"{oattr}_{j}" + for j in range(mod.shape[1 - axis]) + ], + ) + + mdata = MuData(modalities, axis=axis) + + # names along non-axis are concatenated + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in modalities.values()) + assert ( + getattr(mdata, f"{oattr}_names") + == reduce( + lambda x, y: x.append(y), + (getattr(mod, f"{oattr}_names") for mod in modalities.values()), + ) + ).all() + + # names along axis are intersected + axisnames = reduce( + lambda x, y: x.union(y, sort=False), + (getattr(mod, f"{attr}_names") for mod in modalities.values()), + ) + assert mdata.shape[axis] == axisnames.shape[0] + assert (getattr(mdata, f"{attr}_names") == axisnames).all() + + # Variables are different across modalities + for m, mod in modalities.items(): + # Columns are intact in individual modalities + assert "mod" in mod.obs.columns + assert all(mod.obs["mod"] == m) + assert "mod" in mod.var.columns + assert all(mod.var["mod"] == m) + + @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_across", ["intersecting"]) + @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) + def test_update_after_filter_obs_adata(self, modalities, axis): + """ + Check for muon issue #44. + """ + # Replicate in-place filtering in muon: + # mu.pp.filter_obs(mdata['mod1'], 'min_count', lambda x: (x < -2)) + mdata = MuData(modalities, axis=axis) + batches = np.random.choice(["a", "b", "c"], size=mdata.shape[0], replace=True) + mdata.obs["batch"] = batches + + old_obsnames = mdata.obs_names + old_varnames = mdata.var_names + + mdata.mod["mod1"] = mdata["mod1"][mdata["mod1"].obs["min_count"] < -2].copy() + mdata.update() + assert mdata.obs["batch"].isna().sum() == 0 + + assert (mdata.var_names == old_varnames).all() + if axis == 0: + assert (mdata.obs_names == old_obsnames).all() + + @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_across", ["intersecting"]) + @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) + def test_update_after_obs_reordered(self, modalities, axis): + """ + Update should work if obs are reordered. + """ + mdata = MuData(modalities, axis=axis) + mdata.obsm["test_obsm"] = np.random.normal(size=(mdata.n_obs, 2)) + + some_obs_names = mdata.obs_names.values[:2] + + true_obsm_values = [ + mdata.obsm["test_obsm"][np.where(mdata.obs_names.values == name)[0][0]] + for name in some_obs_names + ] + + mdata.mod["mod1"] = mdata["mod1"][::-1].copy() + mdata.update() + + test_obsm_values = [ + mdata.obsm["test_obsm"][np.where(mdata.obs_names == name)[0][0]] + for name in some_obs_names + ] + + assert all( + [all(true_obsm_values[i] == test_obsm_values[i]) for i in range(len(true_obsm_values))] + ) @pytest.mark.usefixtures("filepath_h5mu") -class TestMuData: - @pytest.mark.parametrize("obs_mod", ["unique", "extreme_duplicated"]) +class TestMuDataLegacy: + @pytest.mark.parametrize("obs_mod", ["unique"]) @pytest.mark.parametrize("obs_across", ["intersecting"]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) def test_update_simple(self, modalities): @@ -186,6 +351,7 @@ def test_update_after_filter_obs_adata(self, mdata): # Replicate in-place filtering in muon: # mu.pp.filter_obs(mdata['mod1'], 'min_count', lambda x: (x < -2)) mdata.mod["mod1"] = mdata["mod1"][mdata["mod1"].obs["min_count"] < -2].copy() + old_obsnames = mdata.obs_names mdata.update() assert mdata.obs["batch"].isna().sum() == 0 From 723a63cee331ff7171df9f21673239c7d3c7e90a Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Tue, 30 Sep 2025 11:07:46 +0200 Subject: [PATCH 02/14] make update() work well when indices have duplicates - all tests now also test with duplicated entries - streamline update() to remove redundant code, duplicated calcuations - fix bugs and make it pass all tests --- pyproject.toml | 1 + src/mudata/_core/mudata.py | 214 +++++++++++++------------------------ src/mudata/_core/utils.py | 12 ++- tests/test_update.py | 39 ++++--- 4 files changed, 107 insertions(+), 159 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1e2b5c5..0f0a29d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ ] dependencies = [ "anndata >= 0.10.8", + "pandas >= 1.4.0" ] dynamic = ["version"] diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index f231af5..39e8263 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -599,15 +599,12 @@ def _update_attr( self._update_attr_legacy(attr, axis, join_common, **kwargs) return - prev_index = getattr(self, attr).index - # No _attrhash when upon read # No _attrhash in mudata < 0.2.0 _attrhash = f"_{attr}hash" attr_changed = self._check_changed_attr_names(attr) attr_duplicated = self._check_duplicated_attr_names(attr) - attr_intersecting = self._check_intersecting_attr_names(attr) if attr_duplicated: warnings.warn( @@ -633,8 +630,16 @@ def _update_attr( attrp = getattr(self, attr + "p") attrmap = getattr(self, attr + "map") - # TODO: take advantage when attr_changed[0] == False — only new columns to be added + dfs = [ + getattr(a, attr) + .loc[:, []] + .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) + .add_prefix(m + ":") + for m, a in self.mod.items() + ] + index_order = None + can_update = True # # Join modality .obs/.var tables # @@ -643,36 +648,10 @@ def _update_attr( if not attr_duplicated: # Shared axis if axis == (1 - self._axis) or self._axis == -1: - # We assume attr_intersecting and can't join_common - data_mod = pd.concat( - [ - getattr(a, attr) - .loc[:, []] - .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") - for m, a in self.mod.items() - ], - join="outer", - axis=1, - sort=False, - ) + data_mod = pd.concat(dfs, join="outer", axis=1, sort=False) else: - data_mod = _maybe_coerce_to_bool( - pd.concat( - [ - getattr(a, attr) - .loc[:, []] - .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") - for m, a in self.mod.items() - ], - join="outer", - axis=0, - sort=False, - ) - ) - - for mod in self.mod.keys(): + data_mod = _maybe_coerce_to_bool(pd.concat(dfs, join="outer", axis=0, sort=False)) + for mod, amod in self.mod.items(): colname = mod + ":" + rowcol # use 0 as special value for missing # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write @@ -681,54 +660,38 @@ def _update_attr( col.replace(np.nan, 0, inplace=True) data_mod[colname] = col.astype(np.uint32) - if len(data_global.columns) > 0: - if not attr_intersecting or axis == (1 - self.axis) or self.axis == -1: - data_mod = data_mod.join(data_global, how="left", sort=False) - else: - # In order to preserve the order of the index, instead, - # perform a join based on unique index. - data_mod = _make_index_unique(data_mod) - data_global = _make_index_unique(data_global) - data_mod = data_global.join(data_mod, how="left", sort=False) - data_mod = _restore_index(data_mod) - data_global = _restore_index(data_global) + data_mod = _make_index_unique(data_mod) + data_global = _make_index_unique(data_global) + if data_global.shape[1] > 0: + data_mod = data_global.join(data_mod, how="left", sort=False) + + if data_global.shape[0] > 0: + # reorder new index to conform to the old index as much as possible + kept_idx = data_global.index[data_global.index.isin(data_mod.index)] + new_idx = data_mod.index[~data_mod.index.isin(data_global.index)] + data_mod = data_mod.loc[kept_idx.append(new_idx), :] + + index_order = data_global.index.get_indexer(data_mod.index) + can_update = ( + new_idx.shape[0] == 0 # filtered or reordered + or kept_idx.shape[0] == data_global.shape[0] # new rows only + or data_mod.shape[0] + == data_global.shape[ + 0 + ] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0]) + ) + + data_mod = _restore_index(data_mod) + data_global = _restore_index(data_global) # # General case: with duplicates and/or intersections # else: - dfs = [ - _make_index_unique( - getattr(a, attr) - .loc[:, []] - .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") - ) - for m, a in self.mod.items() - ] - data_mod = pd.concat( - dfs, - join="outer", - axis=axis, - sort=False, - ) - - # pd.concat wrecks the ordering when doing an outer join with a MultiIndex and different data frame shapes - if axis == 1: - newidx = ( - reduce(lambda x, y: x.union(y, sort=False), (df.index for df in dfs)) - .to_frame() - .reset_index(level=1, drop=True) - ) - globalidx = data_global.index.get_level_values(0) - mask = globalidx.isin(newidx.iloc[:, 0]) - if len(mask) > 0: - negativemask = ~newidx.index.get_level_values(0).isin(globalidx) - newidx = pd.MultiIndex.from_frame( - pd.concat( - [newidx.loc[globalidx[mask], :], newidx.iloc[negativemask, :]], axis=0 - ) - ) - data_mod = data_mod.reindex(newidx, copy=False) + dfs = [_make_index_unique(df, force=True) for df in dfs] + if axis == (1 - self._axis) or self._axis == -1: + data_mod = pd.concat(dfs, join="outer", axis=1, sort=False) + else: + data_mod = _maybe_coerce_to_bool(pd.concat(dfs, join="outer", axis=0, sort=False)) data_mod = _restore_index(data_mod) data_mod.index.set_names(rowcol, inplace=True) @@ -741,13 +704,13 @@ def _update_attr( col = data_mod.loc[:, colname] + 1 col.replace(np.nan, 0, inplace=True) col = col.astype(np.uint32) - data_mod.loc[:, colname] = col - data_mod.set_index(colname, append=True, inplace=True) - if mod in attrmap and np.sum(attrmap[mod] > 0) == getattr(amod, attr).shape[0]: - data_global.set_index(attrmap[mod], append=True, inplace=True) + data_mod[colname] = col + if mod in attrmap and np.array_equal(attrmap[mod], col): + data_mod.set_index(colname, append=True, inplace=True) + data_global.set_index(attrmap[mod].ravel(), append=True, inplace=True) data_global.index.set_names(colname, level=-1, inplace=True) - if len(data_global) > 0: + if data_global.shape[0] > 0: if not data_global.index.is_unique: warnings.warn( f"{attr}_names is not unique, global {attr} is present, and {attr}map is empty. The update() is not well-defined, verify if global {attr} map to the correct modality-specific {attr}.", @@ -759,15 +722,26 @@ def _update_attr( data_mod = _make_index_unique(data_mod) data_global = _make_index_unique(data_global) data_mod = data_mod.join(data_global, how="left", sort=False) + + # reorder new index to conform to the old index as much as possible + kept_idx = data_global.index[data_global.index.isin(data_mod.index)] + new_idx = data_mod.index[~data_mod.index.isin(data_global.index)] + data_mod = data_mod.loc[kept_idx.append(new_idx), :] + + index_order = data_global.index.get_indexer(data_mod.index) + can_update = ( + new_idx.shape[0] == 0 # filtered or reordered + or kept_idx.shape[0] == data_global.shape[0] # new rows only + or data_mod.shape[0] + == data_global.shape[ + 0 + ] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0]) + ) + data_mod.reset_index(level=list(range(1, data_mod.index.nlevels)), inplace=True) + data_global.reset_index(level=list(range(1, data_global.index.nlevels)), inplace=True) data_mod.index.set_names(None, inplace=True) - - if data_global.shape[0] > 0: - # reorder new index to conform to the old index as much as possible - kept_idx = data_global.index.isin(data_mod.index) - data_mod = data_mod.loc[ - data_global.index[kept_idx].append(data_global.index[~kept_idx]), : - ] + data_global.index.set_names(None, inplace=True) # get adata positions and remove columns from the data frame mdict = {} @@ -789,48 +763,9 @@ def _update_attr( for mod, mapping in mdict.items(): attrm[mod] = mapping > 0 - now_index = data_mod.index - - if len(prev_index) == 0: - # New object - pass - elif now_index.equals(prev_index): - # Index is the same - pass - else: - keep_index = prev_index.isin(now_index) - new_index = ~now_index.isin(prev_index) - - if new_index.sum() == 0 or ( - keep_index.sum() + new_index.sum() == len(now_index) - and len(now_index) > len(prev_index) - ): - # Another length (filtered) or new modality added - # Update .obsm/.varm (size might have changed) - # NOTE: .get_index doesn't work with duplicated indices - if any(prev_index.duplicated()): - # Assume the relative order of duplicates hasn't changed - # NOTE: .get_loc() for each element is too slow - # We will rename duplicated in prev_index and now_index - # in order to use .get_indexer - # index_order = [ - # prev_index.get_loc(i) if i in prev_index else -1 for i in now_index - # ] - prev_values = prev_index.values.copy() - now_values = now_index.values.copy() - for value in prev_index[np.where(prev_index.duplicated())[0]]: - v_now = np.where(now_index == value)[0] - v_prev = np.where(prev_index.get_loc(value))[0] - for i in range(min(len(v_now), len(v_prev))): - prev_values[v_prev[i]] = f"{str(value)}-{i}" - now_values[v_now[i]] = f"{str(value)}-{i}" - - prev_index = pd.Index(prev_values) - now_index = pd.Index(now_values) - - index_order = prev_index.get_indexer(now_index) - - for mx_key in attrm.keys(): + if index_order is not None: + if can_update: + for mx_key, mx in attrm.items(): if mx_key not in self.mod.keys(): # not a modality name attrm[mx_key] = attrm[mx_key][index_order] attrm[mx_key][index_order == -1] = np.nan @@ -841,11 +776,6 @@ def _update_attr( attrp[mx_key][index_order == -1, :] = -1 attrp[mx_key][:, index_order == -1] = -1 - elif len(now_index) == len(prev_index): - # Renamed since new_index.sum() != 0 - # We have to assume the order hasn't changed - pass - else: raise NotImplementedError( f"{attr}_names seem to have been renamed and filtered at the same time. " @@ -1078,7 +1008,8 @@ def _update_attr_legacy( getattr(a, attr) .drop(columns_common, axis=1) .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") + .add_prefix(m + ":"), + force=True, ) ) for m, a in self.mod.items() @@ -1095,7 +1026,7 @@ def _update_attr_legacy( data_common = pd.concat( [ _maybe_coerce_to_boolean( - _make_index_unique(getattr(a, attr)[columns_common]) + _make_index_unique(getattr(a, attr)[columns_common], force=True) ) for m, a in self.mod.items() ], @@ -1111,7 +1042,8 @@ def _update_attr_legacy( _make_index_unique( getattr(a, attr) .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") + .add_prefix(m + ":"), + force=True, ) for m, a in self.mod.items() ] @@ -1173,8 +1105,8 @@ def _update_attr_legacy( data_mod.reset_index( data_mod.index.names.difference(data_global.index.names), inplace=True ) - data_mod = _make_index_unique(data_mod) - data_global = _make_index_unique(data_global) + data_mod = _make_index_unique(data_mod, force=True) + data_global = _make_index_unique(data_global, force=True) data_mod = data_mod.join(data_global, how="left", sort=False) data_mod.reset_index(level=list(range(1, data_mod.index.nlevels)), inplace=True) data_mod.index.set_names(None, inplace=True) diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index f3daa95..1cfc8b7 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -8,9 +8,9 @@ T = TypeVar("T", pd.Series, pd.DataFrame) -def _make_index_unique(df: pd.DataFrame) -> pd.DataFrame: - dup_idx = np.zeros((df.shape[0],), dtype=np.uint8) - if not df.index.is_unique: +def _make_index_unique(df: pd.DataFrame, force: bool = False) -> pd.DataFrame: + if force or not df.index.is_unique: + dup_idx = np.zeros((df.shape[0],), dtype=np.uint8) duplicates = np.nonzero(df.index.duplicated())[0] cnt = Counter() for dup in duplicates: @@ -22,11 +22,13 @@ def _make_index_unique(df: pd.DataFrame) -> pd.DataFrame: dup_idx = dup_idx.astype(np.min_scalar_type(newval)) dup_idx[dup] = newval cnt[idxval] = newval - return df.set_index(dup_idx, append=True) + return df.set_index(dup_idx, append=True) + else: + return df def _restore_index(df: pd.DataFrame) -> pd.DataFrame: - return df.reset_index(level=-1, drop=True) + return df.reset_index(level=-1, drop=True) if df.index.nlevels > 1 else df def _maybe_coerce_to_boolean(df: T) -> T: diff --git a/tests/test_update.py b/tests/test_update.py index 1823a3c..0012d5e 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -24,9 +24,9 @@ def modalities(request, obs_n, obs_across, obs_mod): if obs_n: if obs_n == "disjoint": mod2_which_obs = np.random.choice( - mods["mod2"].obs_names, size=mods["mod2"].n_obs // 2, replace=False + mods["mod1"].obs_names, size=mods["mod1"].n_obs // 2, replace=False ) - mods["mod2"] = mods["mod2"][mod2_which_obs].copy() + mods["mod1"] = mods["mod1"][mod2_which_obs].copy() if obs_across: if obs_across != "intersecting": @@ -66,7 +66,7 @@ def new_update(self): yield set_options(pull_on_update=None) - @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) @pytest.mark.parametrize("obs_across", ["intersecting"]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) def test_update_simple(self, modalities, axis): @@ -77,8 +77,6 @@ def test_update_simple(self, modalities, axis): """ attr = "obs" if axis == 0 else "var" oattr = "var" if axis == 0 else "obs" - for m, mod in modalities.items(): - setattr(mod, f"{oattr}_names", [f"{m}_{oattr}{j}" for j in range(mod.shape[1 - axis])]) mdata = MuData(modalities, axis=axis) @@ -92,13 +90,27 @@ def test_update_simple(self, modalities, axis): ) ).all() - # names along axis are intersected + # names along axis are unioned axisnames = reduce( lambda x, y: x.union(y, sort=False), (getattr(mod, f"{attr}_names") for mod in modalities.values()), ) assert mdata.shape[axis] == axisnames.shape[0] - assert (getattr(mdata, f"{attr}_names") == axisnames).all() + assert (getattr(mdata, f"{attr}_names").sort_values() == axisnames.sort_values()).all() + + # guards against Pandas scrambling the order. This was the case for pandas < 1.4.0 when using pd.concat with an outer join on a MultiIndex. + # reprex: + # + # import numpy as np + # import pandas as pd + # df1 = pd.DataFrame({"a": np.repeat(np.arange(5), 2), "b": np.tile(np.asarray([0, 1]), 5), "c": np.arange(10)}).set_index("a").set_index("b", append=True) + # df2 = pd.DataFrame({"a": np.repeat(np.arange(10), 2), "b": np.tile(np.asarray([0, 1]), 10), "d": np.arange(20)}).set_index("a").set_index("b", append=True) + # df1 = df1.iloc[::-1, :] + # df = pd.concat((kdf1, df2), axis=1, join="outer", sort=False) + assert ( + getattr(mdata, f"{attr}_names")[: modalities["mod1"].shape[axis]] + == getattr(modalities["mod1"], f"{attr}_names") + ).all() # Variables are different across modalities for m, mod in modalities.items(): @@ -108,7 +120,7 @@ def test_update_simple(self, modalities, axis): assert "mod" in mod.var.columns assert all(mod.var["mod"] == m) - @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) @pytest.mark.parametrize("obs_across", ["intersecting"]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) def test_update_duplicates(self, modalities, axis): @@ -153,7 +165,7 @@ def test_update_duplicates(self, modalities, axis): assert "mod" in mod.var.columns assert all(mod.var["mod"] == m) - @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) @pytest.mark.parametrize("obs_across", ["intersecting"]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) def test_update_intersecting(self, modalities, axis): @@ -203,7 +215,7 @@ def test_update_intersecting(self, modalities, axis): assert "mod" in mod.var.columns assert all(mod.var["mod"] == m) - @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) @pytest.mark.parametrize("obs_across", ["intersecting"]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) def test_update_after_filter_obs_adata(self, modalities, axis): @@ -219,15 +231,16 @@ def test_update_after_filter_obs_adata(self, modalities, axis): old_obsnames = mdata.obs_names old_varnames = mdata.var_names - mdata.mod["mod1"] = mdata["mod1"][mdata["mod1"].obs["min_count"] < -2].copy() + mdata.mod["mod3"] = mdata["mod3"][mdata["mod3"].obs["min_count"] < -2].copy() mdata.update() assert mdata.obs["batch"].isna().sum() == 0 assert (mdata.var_names == old_varnames).all() if axis == 0: - assert (mdata.obs_names == old_obsnames).all() + # check if the order is preserved + assert (mdata.obs_names == old_obsnames[old_obsnames.isin(mdata.obs_names)]).all() - @pytest.mark.parametrize("obs_mod", ["unique"]) + @pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) @pytest.mark.parametrize("obs_across", ["intersecting"]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) def test_update_after_obs_reordered(self, modalities, axis): From 22a1b144424e284dc478a2e84579b9b929018da2 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Tue, 30 Sep 2025 16:57:27 +0200 Subject: [PATCH 03/14] fix inserting modalities --- src/mudata/_core/mudata.py | 22 ++++-- tests/test_update.py | 153 +++++++++++++++++-------------------- 2 files changed, 87 insertions(+), 88 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 39e8263..998dcf5 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -605,6 +605,7 @@ def _update_attr( attr_changed = self._check_changed_attr_names(attr) attr_duplicated = self._check_duplicated_attr_names(attr) + attr_intersecting = self._check_intersecting_attr_names(attr) if attr_duplicated: warnings.warn( @@ -660,8 +661,8 @@ def _update_attr( col.replace(np.nan, 0, inplace=True) data_mod[colname] = col.astype(np.uint32) - data_mod = _make_index_unique(data_mod) - data_global = _make_index_unique(data_global) + data_mod = _make_index_unique(data_mod, force=attr_intersecting) + data_global = _make_index_unique(data_global, force=attr_intersecting) if data_global.shape[1] > 0: data_mod = data_global.join(data_mod, how="left", sort=False) @@ -705,7 +706,13 @@ def _update_attr( col.replace(np.nan, 0, inplace=True) col = col.astype(np.uint32) data_mod[colname] = col - if mod in attrmap and np.array_equal(attrmap[mod], col): + if mod in attrmap and ( + col.shape[0] != data_global.shape[0] + and np.sum(attrmap[mod] > 0) + == getattr(amod, attr).shape[0] # added/removed observations + or col.shape[0] == data_global.shape[0] + and np.array_equal(attrmap[mod], col) # reordered + ): data_mod.set_index(colname, append=True, inplace=True) data_global.set_index(attrmap[mod].ravel(), append=True, inplace=True) data_global.index.set_names(colname, level=-1, inplace=True) @@ -719,8 +726,10 @@ def _update_attr( data_mod.reset_index( data_mod.index.names.difference(data_global.index.names), inplace=True ) - data_mod = _make_index_unique(data_mod) - data_global = _make_index_unique(data_global) + # after inserting a new modality with duplicates, but no duplicates before: + # data_mod.index is not unique + data_global = _make_index_unique(data_global, force=not data_mod.index.is_unique) + data_mod = _make_index_unique(data_mod) data_mod = data_mod.join(data_global, how="left", sort=False) # reorder new index to conform to the old index as much as possible @@ -736,6 +745,9 @@ def _update_attr( == data_global.shape[ 0 ] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0]) + or axis == self._axis + and data_mod.shape[0] + > data_global.shape[0] # new modality added and concacenated ) data_mod.reset_index(level=list(range(1, data_mod.index.nlevels)), inplace=True) diff --git a/tests/test_update.py b/tests/test_update.py index 0012d5e..467aeea 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -44,11 +44,15 @@ def modalities(request, obs_n, obs_across, obs_mod): obs_names[:-1] = obs_names[0] mods["mod1"].obs_names = obs_names + var_names = mods[m].var_names.values.copy() + var_names[1] = var_names[0] + mods[m].var_names = var_names + return mods @pytest.fixture() -def mdata(modalities): +def mdata_legacy(modalities): mdata = MuData(modalities) batches = np.random.choice(["a", "b", "c"], size=mdata.shape[0], replace=True) @@ -57,8 +61,24 @@ def mdata(modalities): return mdata +@pytest.fixture() +def mdata(modalities, axis): + md = MuData(modalities, axis=axis) + + md.obs["batch"] = np.random.choice(["a", "b", "c"], size=md.shape[0], replace=True) + md.var["batch"] = np.random.choice(["d", "e", "f"], size=md.shape[1], replace=True) + + md.obsm["test_obsm"] = np.random.normal(size=(md.n_obs, 2)) + md.varm["test_varm"] = np.random.normal(size=(md.n_var, 2)) + + return md + + @pytest.mark.usefixtures("filepath_h5mu") @pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) +@pytest.mark.parametrize("obs_across", ["intersecting"]) +@pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) class TestMuData: @pytest.fixture(autouse=True) def new_update(self): @@ -66,10 +86,7 @@ def new_update(self): yield set_options(pull_on_update=None) - @pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) - @pytest.mark.parametrize("obs_across", ["intersecting"]) - @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) - def test_update_simple(self, modalities, axis): + def test_update_simple(self, mdata, axis): """ Update should work when - obs_names are the same across modalities, @@ -78,22 +95,20 @@ def test_update_simple(self, modalities, axis): attr = "obs" if axis == 0 else "var" oattr = "var" if axis == 0 else "obs" - mdata = MuData(modalities, axis=axis) - # names along non-axis are concatenated - assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in modalities.values()) + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) assert ( getattr(mdata, f"{oattr}_names") == reduce( lambda x, y: x.append(y), - (getattr(mod, f"{oattr}_names") for mod in modalities.values()), + (getattr(mod, f"{oattr}_names") for mod in mdata.mod.values()), ) ).all() # names along axis are unioned axisnames = reduce( lambda x, y: x.union(y, sort=False), - (getattr(mod, f"{attr}_names") for mod in modalities.values()), + (getattr(mod, f"{attr}_names") for mod in mdata.mod.values()), ) assert mdata.shape[axis] == axisnames.shape[0] assert (getattr(mdata, f"{attr}_names").sort_values() == axisnames.sort_values()).all() @@ -108,66 +123,48 @@ def test_update_simple(self, modalities, axis): # df1 = df1.iloc[::-1, :] # df = pd.concat((kdf1, df2), axis=1, join="outer", sort=False) assert ( - getattr(mdata, f"{attr}_names")[: modalities["mod1"].shape[axis]] - == getattr(modalities["mod1"], f"{attr}_names") + getattr(mdata, f"{attr}_names")[: mdata["mod1"].shape[axis]] + == getattr(mdata["mod1"], f"{attr}_names") ).all() # Variables are different across modalities - for m, mod in modalities.items(): + for m, mod in mdata.mod.items(): # Columns are intact in individual modalities assert "mod" in mod.obs.columns assert all(mod.obs["mod"] == m) assert "mod" in mod.var.columns assert all(mod.var["mod"] == m) - @pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) - @pytest.mark.parametrize("obs_across", ["intersecting"]) - @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) - def test_update_duplicates(self, modalities, axis): - """ - Update should work when - - obs_names are the same across modalities, - - there are duplicated var_names, which are not intersecting - between modalities - """ + def test_update_add_modality(self, modalities, axis): + modnames = list(modalities.keys()) + mdata = MuData({modname: modalities[modname] for modname in modnames[:-2]}, axis=axis) + attr = "obs" if axis == 0 else "var" oattr = "var" if axis == 0 else "obs" - for m, mod in modalities.items(): - setattr( - mod, f"{oattr}_names", [f"{m}_{oattr}{j // 2}" for j in range(mod.shape[1 - axis])] - ) - mdata = MuData(modalities, axis=axis) + for i in (-2, -1): + old_attrnames = getattr(mdata, f"{attr}_names") + old_oattrnames = getattr(mdata, f"{oattr}_names") + + mdata.mod[modnames[i]] = modalities[modnames[i]] + mdata.update() + + attrnames = getattr(mdata, f"{attr}_names") + oattrnames = getattr(mdata, f"{oattr}_names") + assert (attrnames[: old_attrnames.size] == old_attrnames).all() + assert (oattrnames[: old_oattrnames.size] == old_oattrnames).all() + + assert ( + attrnames + == old_attrnames.union( + getattr(modalities[modnames[i]], f"{attr}_names"), sort=False + ) + ).all() + assert ( + oattrnames + == old_oattrnames.append(getattr(modalities[modnames[i]], f"{oattr}_names")) + ).all() - # names along non-axis are concatenated - assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in modalities.values()) - assert ( - getattr(mdata, f"{oattr}_names") - == reduce( - lambda x, y: x.append(y), - (getattr(mod, f"{oattr}_names") for mod in modalities.values()), - ) - ).all() - - # names along axis are intersected - axisnames = reduce( - lambda x, y: x.union(y, sort=False), - (getattr(mod, f"{attr}_names") for mod in modalities.values()), - ) - assert mdata.shape[axis] == axisnames.shape[0] - assert (getattr(mdata, f"{attr}_names") == axisnames).all() - - # Variables are different across modalities - for m, mod in modalities.items(): - # Columns are intact in individual modalities - assert "mod" in mod.obs.columns - assert all(mod.obs["mod"] == m) - assert "mod" in mod.var.columns - assert all(mod.var["mod"] == m) - - @pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) - @pytest.mark.parametrize("obs_across", ["intersecting"]) - @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) def test_update_intersecting(self, modalities, axis): """ Update should work when @@ -215,18 +212,12 @@ def test_update_intersecting(self, modalities, axis): assert "mod" in mod.var.columns assert all(mod.var["mod"] == m) - @pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) - @pytest.mark.parametrize("obs_across", ["intersecting"]) - @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) - def test_update_after_filter_obs_adata(self, modalities, axis): + def test_update_after_filter_obs_adata(self, mdata, axis): """ Check for muon issue #44. """ # Replicate in-place filtering in muon: # mu.pp.filter_obs(mdata['mod1'], 'min_count', lambda x: (x < -2)) - mdata = MuData(modalities, axis=axis) - batches = np.random.choice(["a", "b", "c"], size=mdata.shape[0], replace=True) - mdata.obs["batch"] = batches old_obsnames = mdata.obs_names old_varnames = mdata.var_names @@ -240,16 +231,10 @@ def test_update_after_filter_obs_adata(self, modalities, axis): # check if the order is preserved assert (mdata.obs_names == old_obsnames[old_obsnames.isin(mdata.obs_names)]).all() - @pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) - @pytest.mark.parametrize("obs_across", ["intersecting"]) - @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) - def test_update_after_obs_reordered(self, modalities, axis): + def test_update_after_obs_reordered(self, mdata): """ Update should work if obs are reordered. """ - mdata = MuData(modalities, axis=axis) - mdata.obsm["test_obsm"] = np.random.normal(size=(mdata.n_obs, 2)) - some_obs_names = mdata.obs_names.values[:2] true_obsm_values = [ @@ -357,38 +342,40 @@ def test_update_intersecting(self, modalities): @pytest.mark.parametrize("obs_mod", ["unique"]) @pytest.mark.parametrize("obs_across", ["intersecting"]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) - def test_update_after_filter_obs_adata(self, mdata): + def test_update_after_filter_obs_adata(self, mdata_legacy): """ Check for muon issue #44. """ # Replicate in-place filtering in muon: # mu.pp.filter_obs(mdata['mod1'], 'min_count', lambda x: (x < -2)) - mdata.mod["mod1"] = mdata["mod1"][mdata["mod1"].obs["min_count"] < -2].copy() - old_obsnames = mdata.obs_names - mdata.update() - assert mdata.obs["batch"].isna().sum() == 0 + mdata_legacy.mod["mod1"] = mdata_legacy["mod1"][ + mdata_legacy["mod1"].obs["min_count"] < -2 + ].copy() + old_obsnames = mdata_legacy.obs_names + mdata_legacy.update() + assert mdata_legacy.obs["batch"].isna().sum() == 0 @pytest.mark.parametrize("obs_mod", ["unique", "extreme_duplicated"]) @pytest.mark.parametrize("obs_across", ["intersecting"]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) - def test_update_after_obs_reordered(self, mdata): + def test_update_after_obs_reordered(self, mdata_legacy): """ Update should work if obs are reordered. """ - mdata.obsm["test_obsm"] = np.random.normal(size=(mdata.n_obs, 2)) + mdata_legacy.obsm["test_obsm"] = np.random.normal(size=(mdata_legacy.n_obs, 2)) - some_obs_names = mdata.obs_names.values[:2] + some_obs_names = mdata_legacy.obs_names.values[:2] true_obsm_values = [ - mdata.obsm["test_obsm"][np.where(mdata.obs_names.values == name)[0][0]] + mdata_legacy.obsm["test_obsm"][np.where(mdata_legacy.obs_names.values == name)[0][0]] for name in some_obs_names ] - mdata.mod["mod1"] = mdata["mod1"][::-1].copy() - mdata.update() + mdata_legacy.mod["mod1"] = mdata_legacy["mod1"][::-1].copy() + mdata_legacy.update() test_obsm_values = [ - mdata.obsm["test_obsm"][np.where(mdata.obs_names == name)[0][0]] + mdata_legacy.obsm["test_obsm"][np.where(mdata_legacy.obs_names == name)[0][0]] for name in some_obs_names ] From dfe9ec04b353605ce38d0a2dc6dcc92ca1f00b05 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 1 Oct 2025 11:11:45 +0200 Subject: [PATCH 04/14] fix deleting modalities --- src/mudata/_core/mudata.py | 30 +++++++++++++++++++----------- tests/test_update.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 998dcf5..cb47fa1 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -452,6 +452,8 @@ def _check_changed_attr_names(self, attr: str, columns: bool = False): attr_names_changed, attr_columns_changed = False, False if not hasattr(self, attrhash): attr_names_changed, attr_columns_changed = True, True + elif len(self.mod) < len(getattr(self, attrhash)): + attr_names_changed, attr_columns_changed = True, None else: for m in self.mod.keys(): if m in getattr(self, attrhash): @@ -604,7 +606,13 @@ def _update_attr( _attrhash = f"_{attr}hash" attr_changed = self._check_changed_attr_names(attr) - attr_duplicated = self._check_duplicated_attr_names(attr) + if not any(attr_changed): + # Nothing to update + return + + data_global = getattr(self, attr) + + attr_duplicated = not data_global.index.is_unique or self._check_duplicated_attr_names(attr) attr_intersecting = self._check_intersecting_attr_names(attr) if attr_duplicated: @@ -618,12 +626,6 @@ def _update_attr( stacklevel=2, ) - if not any(attr_changed): - # Nothing to update - return - - data_global = getattr(self, attr) - # Generate unique colnames (rowcol,) = self._find_unique_colnames(attr, 1) @@ -645,7 +647,6 @@ def _update_attr( # Join modality .obs/.var tables # # Main case: no duplicates and no intersection if the axis is not shared - # if not attr_duplicated: # Shared axis if axis == (1 - self._axis) or self._axis == -1: @@ -664,7 +665,7 @@ def _update_attr( data_mod = _make_index_unique(data_mod, force=attr_intersecting) data_global = _make_index_unique(data_global, force=attr_intersecting) if data_global.shape[1] > 0: - data_mod = data_global.join(data_mod, how="left", sort=False) + data_mod = data_mod.join(data_global, how="left", sort=False) if data_global.shape[0] > 0: # reorder new index to conform to the old index as much as possible @@ -728,8 +729,11 @@ def _update_attr( ) # after inserting a new modality with duplicates, but no duplicates before: # data_mod.index is not unique - data_global = _make_index_unique(data_global, force=not data_mod.index.is_unique) - data_mod = _make_index_unique(data_mod) + # after deleting a modality with duplicates: data_global.index is not unique, but + # data_mod.index is unique + need_unique = data_mod.index.is_unique | data_global.index.is_unique + data_global = _make_index_unique(data_global, force=need_unique) + data_mod = _make_index_unique(data_mod, force=need_unique) data_mod = data_mod.join(data_global, how="left", sort=False) # reorder new index to conform to the old index as much as possible @@ -750,6 +754,10 @@ def _update_attr( > data_global.shape[0] # new modality added and concacenated ) + if need_unique: + data_mod = _restore_index(data_mod) + data_global = _restore_index(data_global) + data_mod.reset_index(level=list(range(1, data_mod.index.nlevels)), inplace=True) data_global.reset_index(level=list(range(1, data_global.index.nlevels)), inplace=True) data_mod.index.set_names(None, inplace=True) diff --git a/tests/test_update.py b/tests/test_update.py index 467aeea..0de45c0 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -165,6 +165,38 @@ def test_update_add_modality(self, modalities, axis): == old_oattrnames.append(getattr(modalities[modnames[i]], f"{oattr}_names")) ).all() + def test_update_delete_modality(self, mdata, axis): + modnames = list(mdata.mod.keys()) + attr = "obs" if axis == 0 else "var" + oattr = "var" if axis == 0 else "obs" + + fullbatch = getattr(mdata, attr)["batch"] + fullobatch = getattr(mdata, oattr)["batch"] + keptmask = (getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0) | ( + getattr(mdata, f"{attr}map")[modnames[2]].reshape(-1) > 0 + ) + keptomask = (getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0) | ( + getattr(mdata, f"{oattr}map")[modnames[2]].reshape(-1) > 0 + ) + + del mdata.mod[modnames[0]] + mdata.update() + + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) + assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all() + assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all() + + fullbatch = getattr(mdata, attr)["batch"] + fullobatch = getattr(mdata, oattr)["batch"] + keptmask = getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0 + keptomask = getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0 + del mdata.mod[modnames[2]] + mdata.update() + + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) + assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all() + assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all() + def test_update_intersecting(self, modalities, axis): """ Update should work when From 0be9ecf0229f517357a22a519f931c23f2ca73d0 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 1 Oct 2025 14:37:36 +0200 Subject: [PATCH 05/14] remove obsolete asserts --- tests/test_update.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_update.py b/tests/test_update.py index 0de45c0..9a3d6f3 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -127,14 +127,6 @@ def test_update_simple(self, mdata, axis): == getattr(mdata["mod1"], f"{attr}_names") ).all() - # Variables are different across modalities - for m, mod in mdata.mod.items(): - # Columns are intact in individual modalities - assert "mod" in mod.obs.columns - assert all(mod.obs["mod"] == m) - assert "mod" in mod.var.columns - assert all(mod.var["mod"] == m) - def test_update_add_modality(self, modalities, axis): modnames = list(modalities.keys()) mdata = MuData({modname: modalities[modname] for modname in modnames[:-2]}, axis=axis) From 082433f4a8d8d81925034818e674e7e67a6c57df Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Thu, 2 Oct 2025 18:14:32 +0200 Subject: [PATCH 06/14] more comprehensive tests and fixes for _update_attr() --- src/mudata/_core/mudata.py | 62 ++++++++++++++----------- tests/test_update.py | 94 +++++++++++++++++++++++++------------- 2 files changed, 99 insertions(+), 57 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index cb47fa1..c868d51 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -611,21 +611,11 @@ def _update_attr( return data_global = getattr(self, attr) + prev_index = data_global.index attr_duplicated = not data_global.index.is_unique or self._check_duplicated_attr_names(attr) attr_intersecting = self._check_intersecting_attr_names(attr) - if attr_duplicated: - warnings.warn( - f"{attr}_names are not unique. To make them unique, call `.{attr}_names_make_unique`.", - stacklevel=2, - ) - if self._axis == -1: - warnings.warn( - f"Behaviour is not defined with axis=-1, {attr}_names need to be made unique first.", - stacklevel=2, - ) - # Generate unique colnames (rowcol,) = self._find_unique_colnames(attr, 1) @@ -707,16 +697,19 @@ def _update_attr( col.replace(np.nan, 0, inplace=True) col = col.astype(np.uint32) data_mod[colname] = col - if mod in attrmap and ( - col.shape[0] != data_global.shape[0] - and np.sum(attrmap[mod] > 0) - == getattr(amod, attr).shape[0] # added/removed observations - or col.shape[0] == data_global.shape[0] - and np.array_equal(attrmap[mod], col) # reordered - ): - data_mod.set_index(colname, append=True, inplace=True) - data_global.set_index(attrmap[mod].ravel(), append=True, inplace=True) - data_global.index.set_names(colname, level=-1, inplace=True) + if mod in attrmap: + modmap = attrmap[mod].reshape(-1) + modmask = modmap > 0 + # only use unchanged modalities for ordering + if ( + modmask.sum() == getattr(amod, attr).shape[0] + and ( + getattr(amod, attr).index[modmap[modmask] - 1] == prev_index[modmask] + ).all() + ): + data_mod.set_index(colname, append=True, inplace=True) + data_global.set_index(attrmap[mod].reshape(-1), append=True, inplace=True) + data_global.index.set_names(colname, level=-1, inplace=True) if data_global.shape[0] > 0: if not data_global.index.is_unique: @@ -749,9 +742,11 @@ def _update_attr( == data_global.shape[ 0 ] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0]) - or axis == self._axis - and data_mod.shape[0] - > data_global.shape[0] # new modality added and concacenated + or ( + axis == self._axis + and axis != -1 + and data_mod.shape[0] > data_global.shape[0] + ) # new modality added and concacenated ) if need_unique: @@ -770,6 +765,15 @@ def _update_attr( mdict[m] = data_mod[colname].to_numpy() data_mod.drop(colname, axis=1, inplace=True) + if not data_mod.index.is_unique: + warnings.warn( + f"{attr}_names are not unique. To make them unique, call `.{attr}_names_make_unique`." + ) + if self._axis == -1: + warnings.warn( + f"Behaviour is not defined with axis=-1, {attr}_names need to be made unique first." + ) + setattr( self, "_" + attr, @@ -787,8 +791,14 @@ def _update_attr( if can_update: for mx_key, mx in attrm.items(): if mx_key not in self.mod.keys(): # not a modality name - attrm[mx_key] = attrm[mx_key][index_order] - attrm[mx_key][index_order == -1] = np.nan + cattr = attrm[mx_key] + if isinstance(cattr, pd.DataFrame): + cattr = cattr.iloc[index_order, :] + cattr.iloc[index_order == -1, :] = pd.NA + else: + cattr = cattr[index_order] + cattr[index_order == -1] = np.nan + attrm[mx_key] = cattr # Update .obsp/.varp (size might have changed) for mx_key in attrp.keys(): diff --git a/tests/test_update.py b/tests/test_update.py index 9a3d6f3..76117b4 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -34,19 +34,21 @@ def modalities(request, obs_n, obs_across, obs_mod): if obs_mod: if obs_mod == "duplicated": - for m in ["mod1", "mod2"]: - # Index does not support mutable operations - obs_names = mods[m].obs_names.values.copy() - obs_names[1] = obs_names[0] - mods[m].obs_names = obs_names + obsnames2 = mods["mod2"].obs_names.to_numpy() + obsnames3 = mods["mod3"].obs_names.to_numpy() + varnames2 = mods["mod2"].var_names.to_numpy() + varnames3 = mods["mod3"].var_names.to_numpy() + obsnames2[0] = obsnames2[1] = obsnames3[1] = "testobs" + varnames2[0] = varnames2[1] = varnames3[1] = "testvar" + mods["mod2"].obs_names = obsnames2 + mods["mod3"].obs_names = obsnames3 + mods["mod2"].var_names = varnames2 + mods["mod3"].var_names = varnames3 elif obs_mod == "extreme_duplicated": # integer overflow: https://github.com/scverse/mudata/issues/107 obs_names = mods["mod1"].obs_names.to_numpy() obs_names[:-1] = obs_names[0] mods["mod1"].obs_names = obs_names - var_names = mods[m].var_names.values.copy() - var_names[1] = var_names[0] - mods[m].var_names = var_names return mods @@ -68,8 +70,8 @@ def mdata(modalities, axis): md.obs["batch"] = np.random.choice(["a", "b", "c"], size=md.shape[0], replace=True) md.var["batch"] = np.random.choice(["d", "e", "f"], size=md.shape[1], replace=True) - md.obsm["test_obsm"] = np.random.normal(size=(md.n_obs, 2)) - md.varm["test_varm"] = np.random.normal(size=(md.n_var, 2)) + md.obsm["test"] = np.random.normal(size=(md.n_obs, 2)) + md.varm["test"] = np.random.normal(size=(md.n_var, 2)) return md @@ -86,6 +88,14 @@ def new_update(self): yield set_options(pull_on_update=None) + @staticmethod + def get_attrm_values(mdata, attr, key, names): + attrm = getattr(mdata, f"{attr}m") + index = getattr(mdata, f"{attr}_names") + return np.concatenate( + [np.atleast_1d(attrm[key][np.nonzero(index == name)[0]]) for name in names] + ) + def test_update_simple(self, mdata, axis): """ Update should work when @@ -138,9 +148,25 @@ def test_update_add_modality(self, modalities, axis): old_attrnames = getattr(mdata, f"{attr}_names") old_oattrnames = getattr(mdata, f"{oattr}_names") + some_obs_names = mdata.obs_names[:2] + mdata.obsm["test"] = np.random.normal(size=(mdata.n_obs, 1)) + true_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) + mdata.mod[modnames[i]] = modalities[modnames[i]] mdata.update() + test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) + if axis == 1: + assert np.isnan(mdata.obsm["test"]).sum() == modalities[modnames[i]].n_obs + assert np.all(np.isnan(mdata.obsm["test"][-modalities[modnames[i]].n_obs :])) + assert np.all(~np.isnan(mdata.obsm["test"][: -modalities[modnames[i]].n_obs])) + assert ( + test_obsm_values[~np.isnan(test_obsm_values)].reshape(-1) + == true_obsm_values.reshape(-1) + ).all() + else: + assert (test_obsm_values == true_obsm_values).all() + attrnames = getattr(mdata, f"{attr}_names") oattrnames = getattr(mdata, f"{oattr}_names") assert (attrnames[: old_attrnames.size] == old_attrnames).all() @@ -161,9 +187,13 @@ def test_update_delete_modality(self, mdata, axis): modnames = list(mdata.mod.keys()) attr = "obs" if axis == 0 else "var" oattr = "var" if axis == 0 else "obs" + attrm = f"{attr}m" + oattrm = f"{oattr}m" fullbatch = getattr(mdata, attr)["batch"] fullobatch = getattr(mdata, oattr)["batch"] + fulltestm = getattr(mdata, attrm)["test"] + fullotestm = getattr(mdata, oattrm)["test"] keptmask = (getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0) | ( getattr(mdata, f"{attr}map")[modnames[2]].reshape(-1) > 0 ) @@ -175,19 +205,26 @@ def test_update_delete_modality(self, mdata, axis): mdata.update() assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) - assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all() assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all() + assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all() + assert (getattr(mdata, attrm)["test"] == fulltestm[keptmask, :]).all() + assert (getattr(mdata, oattrm)["test"] == fullotestm[keptomask, :]).all() fullbatch = getattr(mdata, attr)["batch"] fullobatch = getattr(mdata, oattr)["batch"] + fulltestm = getattr(mdata, attrm)["test"] + fullotestm = getattr(mdata, oattrm)["test"] keptmask = getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0 keptomask = getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0 + del mdata.mod[modnames[2]] mdata.update() assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all() assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all() + assert (getattr(mdata, attrm)["test"] == fulltestm[keptmask, :]).all() + assert (getattr(mdata, oattrm)["test"] == fullotestm[keptomask, :]).all() def test_update_intersecting(self, modalities, axis): """ @@ -220,7 +257,7 @@ def test_update_intersecting(self, modalities, axis): ) ).all() - # names along axis are intersected + # names along axis are unioned axisnames = reduce( lambda x, y: x.union(y, sort=False), (getattr(mod, f"{attr}_names") for mod in modalities.values()), @@ -228,14 +265,6 @@ def test_update_intersecting(self, modalities, axis): assert mdata.shape[axis] == axisnames.shape[0] assert (getattr(mdata, f"{attr}_names") == axisnames).all() - # Variables are different across modalities - for m, mod in modalities.items(): - # Columns are intact in individual modalities - assert "mod" in mod.obs.columns - assert all(mod.obs["mod"] == m) - assert "mod" in mod.var.columns - assert all(mod.var["mod"] == m) - def test_update_after_filter_obs_adata(self, mdata, axis): """ Check for muon issue #44. @@ -246,6 +275,14 @@ def test_update_after_filter_obs_adata(self, mdata, axis): old_obsnames = mdata.obs_names old_varnames = mdata.var_names + filtermask = mdata["mod3"].obs["min_count"] < -2 + fullfiltermask = mdata.obsmap["mod3"].copy() > 0 + fullfiltermask[fullfiltermask] = filtermask + keptmask = (mdata.obsmap["mod1"] > 0) | (mdata.obsmap["mod2"] > 0) | fullfiltermask + + some_obs_names = mdata[keptmask, :].obs_names.values[:2] + true_obsm_values = self.get_attrm_values(mdata[keptmask], "obs", "test", some_obs_names) + mdata.mod["mod3"] = mdata["mod3"][mdata["mod3"].obs["min_count"] < -2].copy() mdata.update() assert mdata.obs["batch"].isna().sum() == 0 @@ -255,28 +292,23 @@ def test_update_after_filter_obs_adata(self, mdata, axis): # check if the order is preserved assert (mdata.obs_names == old_obsnames[old_obsnames.isin(mdata.obs_names)]).all() + test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) + assert (true_obsm_values == test_obsm_values).all() + def test_update_after_obs_reordered(self, mdata): """ Update should work if obs are reordered. """ some_obs_names = mdata.obs_names.values[:2] - true_obsm_values = [ - mdata.obsm["test_obsm"][np.where(mdata.obs_names.values == name)[0][0]] - for name in some_obs_names - ] + true_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) mdata.mod["mod1"] = mdata["mod1"][::-1].copy() mdata.update() - test_obsm_values = [ - mdata.obsm["test_obsm"][np.where(mdata.obs_names == name)[0][0]] - for name in some_obs_names - ] + test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) - assert all( - [all(true_obsm_values[i] == test_obsm_values[i]) for i in range(len(true_obsm_values))] - ) + assert (true_obsm_values == test_obsm_values).all() @pytest.mark.usefixtures("filepath_h5mu") From 826c8cede04061da98011aacc6afb939e3ed275d Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 24 Oct 2025 15:20:42 +0200 Subject: [PATCH 07/14] add test for extremely duplicated names --- tests/test_update.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_update.py b/tests/test_update.py index 76117b4..eac828f 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -45,10 +45,12 @@ def modalities(request, obs_n, obs_across, obs_mod): mods["mod2"].var_names = varnames2 mods["mod3"].var_names = varnames3 elif obs_mod == "extreme_duplicated": # integer overflow: https://github.com/scverse/mudata/issues/107 - obs_names = mods["mod1"].obs_names.to_numpy() - obs_names[:-1] = obs_names[0] - mods["mod1"].obs_names = obs_names - + obsnames2 = mods["mod2"].obs_names.to_numpy() + varnames2 = mods["mod2"].var_names.to_numpy() + obsnames2[:-1] = obsnames2[0] = "testobs" + varnames2[:-1] = varnames2[0] = "testvar" + mods["mod2"].obs_names = obsnames2 + mods["mod2"].var_names = varnames2 return mods @@ -78,7 +80,7 @@ def mdata(modalities, axis): @pytest.mark.usefixtures("filepath_h5mu") @pytest.mark.parametrize("axis", [0, 1]) -@pytest.mark.parametrize("obs_mod", ["unique", "duplicated"]) +@pytest.mark.parametrize("obs_mod", ["unique", "duplicated", "extreme_duplicated"]) @pytest.mark.parametrize("obs_across", ["intersecting"]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) class TestMuData: From d8610cbe9d53720e175175346bfd317242b35501 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 24 Oct 2025 17:16:26 +0200 Subject: [PATCH 08/14] add tests for dtype --- tests/test_update.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/test_update.py b/tests/test_update.py index eac828f..818d47f 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -44,7 +44,9 @@ def modalities(request, obs_n, obs_across, obs_mod): mods["mod3"].obs_names = obsnames3 mods["mod2"].var_names = varnames2 mods["mod3"].var_names = varnames3 - elif obs_mod == "extreme_duplicated": # integer overflow: https://github.com/scverse/mudata/issues/107 + elif ( + obs_mod == "extreme_duplicated" + ): # integer overflow: https://github.com/scverse/mudata/issues/107 obsnames2 = mods["mod2"].obs_names.to_numpy() varnames2 = mods["mod2"].var_names.to_numpy() obsnames2[:-1] = obsnames2[0] = "testobs" @@ -107,6 +109,10 @@ def test_update_simple(self, mdata, axis): attr = "obs" if axis == 0 else "var" oattr = "var" if axis == 0 else "obs" + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" + # names along non-axis are concatenated assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) assert ( @@ -157,6 +163,10 @@ def test_update_add_modality(self, modalities, axis): mdata.mod[modnames[i]] = modalities[modnames[i]] mdata.update() + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" + test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) if axis == 1: assert np.isnan(mdata.obsm["test"]).sum() == modalities[modnames[i]].n_obs @@ -206,6 +216,10 @@ def test_update_delete_modality(self, mdata, axis): del mdata.mod[modnames[0]] mdata.update() + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all() assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all() @@ -249,6 +263,10 @@ def test_update_intersecting(self, modalities, axis): mdata = MuData(modalities, axis=axis) + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" + # names along non-axis are concatenated assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in modalities.values()) assert ( @@ -287,6 +305,11 @@ def test_update_after_filter_obs_adata(self, mdata, axis): mdata.mod["mod3"] = mdata["mod3"][mdata["mod3"].obs["min_count"] < -2].copy() mdata.update() + + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" + assert mdata.obs["batch"].isna().sum() == 0 assert (mdata.var_names == old_varnames).all() @@ -308,6 +331,10 @@ def test_update_after_obs_reordered(self, mdata): mdata.mod["mod1"] = mdata["mod1"][::-1].copy() mdata.update() + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" + test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) assert (true_obsm_values == test_obsm_values).all() From c1ba851248c217c92ce9844b68aab528d82337d5 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 3 Nov 2025 13:22:17 +0100 Subject: [PATCH 09/14] deduplicate processing of the map column --- src/mudata/_core/mudata.py | 22 +++++----------------- src/mudata/_core/utils.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index c868d51..17a428a 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -35,6 +35,7 @@ _maybe_coerce_to_int, _restore_index, _update_and_concat, + fix_attrmap_col, ) from .views import DictView @@ -643,14 +644,8 @@ def _update_attr( data_mod = pd.concat(dfs, join="outer", axis=1, sort=False) else: data_mod = _maybe_coerce_to_bool(pd.concat(dfs, join="outer", axis=0, sort=False)) - for mod, amod in self.mod.items(): - colname = mod + ":" + rowcol - # use 0 as special value for missing - # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write - # also, this is compatible to Muon.jl - col = data_mod[colname] + 1 - col.replace(np.nan, 0, inplace=True) - data_mod[colname] = col.astype(np.uint32) + for mod in self.mod.keys(): + fix_attrmap_col(data_mod, mod, rowcol) data_mod = _make_index_unique(data_mod, force=attr_intersecting) data_global = _make_index_unique(data_global, force=attr_intersecting) @@ -689,16 +684,9 @@ def _update_attr( data_mod.index.set_names(rowcol, inplace=True) data_global.index.set_names(rowcol, inplace=True) for mod, amod in self.mod.items(): - colname = mod + ":" + rowcol - # use 0 as special value for missing - # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write - # also, this is compatible to Muon.jl - col = data_mod.loc[:, colname] + 1 - col.replace(np.nan, 0, inplace=True) - col = col.astype(np.uint32) - data_mod[colname] = col + colname = fix_attrmap_col(data_mod, mod, rowcol) if mod in attrmap: - modmap = attrmap[mod].reshape(-1) + modmap = attrmap[mod].ravel() modmask = modmap > 0 # only use unchanged modalities for ordering if ( diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index 1cfc8b7..823c6e9 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -157,3 +157,14 @@ def _maybe_coerce_to_int(df: T) -> T: pass return df + + +def fix_attrmap_col(data_mod: pd.DataFrame, mod: str, rowcol: str) -> str: + colname = mod + ":" + rowcol + # use 0 as special value for missing + # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write + # also, this is compatible to Muon.jl + col = data_mod[colname] + 1 + col.replace(np.nan, 0, inplace=True) + data_mod[colname] = col.astype(np.uint32) + return colname From 106a3622955e543a3497e0d4620622412419bdf6 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 3 Nov 2025 13:35:41 +0100 Subject: [PATCH 10/14] deduplicate reordering of data_mod and attrm --- src/mudata/_core/mudata.py | 42 ++++++++------------------------------ src/mudata/_core/utils.py | 29 +++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 17a428a..d6e3d47 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -35,7 +35,8 @@ _maybe_coerce_to_int, _restore_index, _update_and_concat, - fix_attrmap_col, + update_fix_attrmap_col, + update_reorder_df_and_attrm_index, ) from .views import DictView @@ -645,7 +646,7 @@ def _update_attr( else: data_mod = _maybe_coerce_to_bool(pd.concat(dfs, join="outer", axis=0, sort=False)) for mod in self.mod.keys(): - fix_attrmap_col(data_mod, mod, rowcol) + update_fix_attrmap_col(data_mod, mod, rowcol) data_mod = _make_index_unique(data_mod, force=attr_intersecting) data_global = _make_index_unique(data_global, force=attr_intersecting) @@ -653,19 +654,8 @@ def _update_attr( data_mod = data_mod.join(data_global, how="left", sort=False) if data_global.shape[0] > 0: - # reorder new index to conform to the old index as much as possible - kept_idx = data_global.index[data_global.index.isin(data_mod.index)] - new_idx = data_mod.index[~data_mod.index.isin(data_global.index)] - data_mod = data_mod.loc[kept_idx.append(new_idx), :] - - index_order = data_global.index.get_indexer(data_mod.index) - can_update = ( - new_idx.shape[0] == 0 # filtered or reordered - or kept_idx.shape[0] == data_global.shape[0] # new rows only - or data_mod.shape[0] - == data_global.shape[ - 0 - ] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0]) + data_mod, index_order, can_update = update_reorder_df_and_attrm_index( + data_mod, data_global, axis, self.axis ) data_mod = _restore_index(data_mod) @@ -684,7 +674,7 @@ def _update_attr( data_mod.index.set_names(rowcol, inplace=True) data_global.index.set_names(rowcol, inplace=True) for mod, amod in self.mod.items(): - colname = fix_attrmap_col(data_mod, mod, rowcol) + colname = update_fix_attrmap_col(data_mod, mod, rowcol) if mod in attrmap: modmap = attrmap[mod].ravel() modmask = modmap > 0 @@ -717,24 +707,8 @@ def _update_attr( data_mod = _make_index_unique(data_mod, force=need_unique) data_mod = data_mod.join(data_global, how="left", sort=False) - # reorder new index to conform to the old index as much as possible - kept_idx = data_global.index[data_global.index.isin(data_mod.index)] - new_idx = data_mod.index[~data_mod.index.isin(data_global.index)] - data_mod = data_mod.loc[kept_idx.append(new_idx), :] - - index_order = data_global.index.get_indexer(data_mod.index) - can_update = ( - new_idx.shape[0] == 0 # filtered or reordered - or kept_idx.shape[0] == data_global.shape[0] # new rows only - or data_mod.shape[0] - == data_global.shape[ - 0 - ] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0]) - or ( - axis == self._axis - and axis != -1 - and data_mod.shape[0] > data_global.shape[0] - ) # new modality added and concacenated + data_mod, index_order, can_update = update_reorder_df_and_attrm_index( + data_mod, data_global, axis, self.axis ) if need_unique: diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index 823c6e9..c615d93 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -159,7 +159,7 @@ def _maybe_coerce_to_int(df: T) -> T: return df -def fix_attrmap_col(data_mod: pd.DataFrame, mod: str, rowcol: str) -> str: +def update_fix_attrmap_col(data_mod: pd.DataFrame, mod: str, rowcol: str) -> str: colname = mod + ":" + rowcol # use 0 as special value for missing # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write @@ -168,3 +168,30 @@ def fix_attrmap_col(data_mod: pd.DataFrame, mod: str, rowcol: str) -> str: col.replace(np.nan, 0, inplace=True) data_mod[colname] = col.astype(np.uint32) return colname + + +def update_reorder_df_and_attrm_index( + data_mod: pd.DataFrame, + data_global: pd.DataFrame, + axis: Literal[-1, 0, 1], + mdaxis: Literal[-1, 0, 1], +) -> tuple[pd.DataFrame, np.ndarray[np.intp], bool]: + # reorder new index to conform to the old index as much as possible + kept_idx = data_global.index[data_global.index.isin(data_mod.index)] + new_idx = data_mod.index[~data_mod.index.isin(data_global.index)] + data_mod = data_mod.loc[kept_idx.append(new_idx), :] + + index_order = data_global.index.get_indexer(data_mod.index) + can_update = ( + new_idx.shape[0] == 0 # filtered or reordered + or kept_idx.shape[0] == data_global.shape[0] # new rows only + or data_mod.shape[0] + == data_global.shape[ + 0 + ] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0]) + or ( + axis == mdaxis and axis != -1 and data_mod.shape[0] > data_global.shape[0] + ) # new modality added and concacenated + ) + + return data_mod, index_order, can_update From 2d04cea774805ebd2ae9b342b8e6dde466c94240 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 3 Nov 2025 13:51:14 +0100 Subject: [PATCH 11/14] remove unnecessary boolean coercion --- src/mudata/_core/mudata.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index d6e3d47..a93ff41 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -628,8 +628,7 @@ def _update_attr( dfs = [ getattr(a, attr) .loc[:, []] - .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") + .assign(**{f"{m}:{rowcol}": np.arange(getattr(a, attr).shape[0])}) for m, a in self.mod.items() ] @@ -641,10 +640,12 @@ def _update_attr( # Main case: no duplicates and no intersection if the axis is not shared if not attr_duplicated: # Shared axis - if axis == (1 - self._axis) or self._axis == -1: - data_mod = pd.concat(dfs, join="outer", axis=1, sort=False) - else: - data_mod = _maybe_coerce_to_bool(pd.concat(dfs, join="outer", axis=0, sort=False)) + data_mod = pd.concat( + dfs, + join="outer", + axis=1 if axis == (1 - self._axis) or self._axis == -1 else 0, + sort=False, + ) for mod in self.mod.keys(): update_fix_attrmap_col(data_mod, mod, rowcol) @@ -665,10 +666,12 @@ def _update_attr( # else: dfs = [_make_index_unique(df, force=True) for df in dfs] - if axis == (1 - self._axis) or self._axis == -1: - data_mod = pd.concat(dfs, join="outer", axis=1, sort=False) - else: - data_mod = _maybe_coerce_to_bool(pd.concat(dfs, join="outer", axis=0, sort=False)) + data_mod = pd.concat( + dfs, + join="outer", + axis=1 if axis == (1 - self._axis) or self._axis == -1 else 0, + sort=False, + ) data_mod = _restore_index(data_mod) data_mod.index.set_names(rowcol, inplace=True) From b0091b2f3384fefd9802269fddb504dfa11cdf5f Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 17 Nov 2025 09:26:31 +0100 Subject: [PATCH 12/14] use inner functions --- src/mudata/_core/mudata.py | 52 ++++++++++++++++++++++++++++++-------- src/mudata/_core/utils.py | 38 ---------------------------- 2 files changed, 42 insertions(+), 48 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index a93ff41..3d95b44 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -35,8 +35,6 @@ _maybe_coerce_to_int, _restore_index, _update_and_concat, - update_fix_attrmap_col, - update_reorder_df_and_attrm_index, ) from .views import DictView @@ -634,6 +632,42 @@ def _update_attr( index_order = None can_update = True + + def fix_attrmap_col(data_mod: pd.DataFrame, mod: str, rowcol: str) -> str: + colname = mod + ":" + rowcol + # use 0 as special value for missing + # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write + # also, this is compatible to Muon.jl + col = data_mod[colname] + 1 + col.replace(np.nan, 0, inplace=True) + data_mod[colname] = col.astype(np.uint32) + return colname + + kept_idx: Any + new_idx: Any + + def reorder_data_mod(): + nonlocal kept_idx, new_idx, data_mod + # reorder new index to conform to the old index as much as possible + kept_idx = data_global.index[data_global.index.isin(data_mod.index)] + new_idx = data_mod.index[~data_mod.index.isin(data_global.index)] + data_mod = data_mod.loc[kept_idx.append(new_idx), :] + + def calc_attrm_update(): + nonlocal index_order, can_update + index_order = data_global.index.get_indexer(data_mod.index) + can_update = ( + new_idx.shape[0] == 0 # filtered or reordered + or kept_idx.shape[0] == data_global.shape[0] # new rows only + or data_mod.shape[0] + == data_global.shape[ + 0 + ] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0]) + or ( + axis == self.axis and axis != -1 and data_mod.shape[0] > data_global.shape[0] + ) # new modality added and concacenated + ) + # # Join modality .obs/.var tables # @@ -647,7 +681,7 @@ def _update_attr( sort=False, ) for mod in self.mod.keys(): - update_fix_attrmap_col(data_mod, mod, rowcol) + fix_attrmap_col(data_mod, mod, rowcol) data_mod = _make_index_unique(data_mod, force=attr_intersecting) data_global = _make_index_unique(data_global, force=attr_intersecting) @@ -655,9 +689,8 @@ def _update_attr( data_mod = data_mod.join(data_global, how="left", sort=False) if data_global.shape[0] > 0: - data_mod, index_order, can_update = update_reorder_df_and_attrm_index( - data_mod, data_global, axis, self.axis - ) + reorder_data_mod() + calc_attrm_update() data_mod = _restore_index(data_mod) data_global = _restore_index(data_global) @@ -677,7 +710,7 @@ def _update_attr( data_mod.index.set_names(rowcol, inplace=True) data_global.index.set_names(rowcol, inplace=True) for mod, amod in self.mod.items(): - colname = update_fix_attrmap_col(data_mod, mod, rowcol) + colname = fix_attrmap_col(data_mod, mod, rowcol) if mod in attrmap: modmap = attrmap[mod].ravel() modmask = modmap > 0 @@ -710,9 +743,8 @@ def _update_attr( data_mod = _make_index_unique(data_mod, force=need_unique) data_mod = data_mod.join(data_global, how="left", sort=False) - data_mod, index_order, can_update = update_reorder_df_and_attrm_index( - data_mod, data_global, axis, self.axis - ) + reorder_data_mod() + calc_attrm_update() if need_unique: data_mod = _restore_index(data_mod) diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index c615d93..1cfc8b7 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -157,41 +157,3 @@ def _maybe_coerce_to_int(df: T) -> T: pass return df - - -def update_fix_attrmap_col(data_mod: pd.DataFrame, mod: str, rowcol: str) -> str: - colname = mod + ":" + rowcol - # use 0 as special value for missing - # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write - # also, this is compatible to Muon.jl - col = data_mod[colname] + 1 - col.replace(np.nan, 0, inplace=True) - data_mod[colname] = col.astype(np.uint32) - return colname - - -def update_reorder_df_and_attrm_index( - data_mod: pd.DataFrame, - data_global: pd.DataFrame, - axis: Literal[-1, 0, 1], - mdaxis: Literal[-1, 0, 1], -) -> tuple[pd.DataFrame, np.ndarray[np.intp], bool]: - # reorder new index to conform to the old index as much as possible - kept_idx = data_global.index[data_global.index.isin(data_mod.index)] - new_idx = data_mod.index[~data_mod.index.isin(data_global.index)] - data_mod = data_mod.loc[kept_idx.append(new_idx), :] - - index_order = data_global.index.get_indexer(data_mod.index) - can_update = ( - new_idx.shape[0] == 0 # filtered or reordered - or kept_idx.shape[0] == data_global.shape[0] # new rows only - or data_mod.shape[0] - == data_global.shape[ - 0 - ] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0]) - or ( - axis == mdaxis and axis != -1 and data_mod.shape[0] > data_global.shape[0] - ) # new modality added and concacenated - ) - - return data_mod, index_order, can_update From dd212ed99f116472c1c8515e36786d95f870c0b2 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 17 Nov 2025 09:31:49 +0100 Subject: [PATCH 13/14] silence ruff false positives --- src/mudata/_core/mudata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 3d95b44..328010f 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -657,8 +657,8 @@ def calc_attrm_update(): nonlocal index_order, can_update index_order = data_global.index.get_indexer(data_mod.index) can_update = ( - new_idx.shape[0] == 0 # filtered or reordered - or kept_idx.shape[0] == data_global.shape[0] # new rows only + new_idx.shape[0] == 0 # noqa: F821 filtered or reordered + or kept_idx.shape[0] == data_global.shape[0] # noqa: F821 new rows only or data_mod.shape[0] == data_global.shape[ 0 From 826c2b93edc64e09616e6523d59068e29a3960f3 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Tue, 9 Dec 2025 17:31:17 +0100 Subject: [PATCH 14/14] minor simplification --- src/mudata/_core/utils.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index 1cfc8b7..41a9f45 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -9,23 +9,23 @@ def _make_index_unique(df: pd.DataFrame, force: bool = False) -> pd.DataFrame: - if force or not df.index.is_unique: - dup_idx = np.zeros((df.shape[0],), dtype=np.uint8) - duplicates = np.nonzero(df.index.duplicated())[0] - cnt = Counter() - for dup in duplicates: - idxval = df.index[dup] - newval = cnt[idxval] + 1 - try: - dup_idx[dup] = newval - except OverflowError: - dup_idx = dup_idx.astype(np.min_scalar_type(newval)) - dup_idx[dup] = newval - cnt[idxval] = newval - return df.set_index(dup_idx, append=True) - else: + if not force and df.index.is_unique: return df + dup_idx = np.zeros((df.shape[0],), dtype=np.uint8) + duplicates = np.nonzero(df.index.duplicated())[0] + cnt = Counter() + for dup in duplicates: + idxval = df.index[dup] + newval = cnt[idxval] + 1 + try: + dup_idx[dup] = newval + except OverflowError: + dup_idx = dup_idx.astype(np.min_scalar_type(newval)) + dup_idx[dup] = newval + cnt[idxval] = newval + return df.set_index(dup_idx, append=True) + def _restore_index(df: pd.DataFrame) -> pd.DataFrame: return df.reset_index(level=-1, drop=True) if df.index.nlevels > 1 else df