diff --git a/pyproject.toml b/pyproject.toml
index 1e2b5c5..0f0a29d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ classifiers = [
 ]
 dependencies = [
     "anndata >= 0.10.8",
+    "pandas >= 1.4.0"
 ]
 dynamic = ["version"]
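The raised pandas floor matches the ordering bug that the rewritten update() depends on: per the reprex comment in tests/test_update.py below, pandas < 1.4.0 could scramble row order when pd.concat did an outer join over a MultiIndex. A minimal standalone sketch of that failure mode, adapted from that reprex (frame and column names are illustrative only):

    import numpy as np
    import pandas as pd

    df1 = pd.DataFrame(
        {"a": np.repeat(np.arange(5), 2), "b": np.tile(np.asarray([0, 1]), 5), "c": np.arange(10)}
    ).set_index("a").set_index("b", append=True)
    df2 = pd.DataFrame(
        {"a": np.repeat(np.arange(10), 2), "b": np.tile(np.asarray([0, 1]), 10), "d": np.arange(20)}
    ).set_index("a").set_index("b", append=True)
    df1 = df1.iloc[::-1, :]  # reverse df1 so the two inputs disagree on row order

    # with pandas < 1.4.0 the outer join could reorder the rows of df1 in the result;
    # on pandas >= 1.4.0 the index order is stable
    df = pd.concat((df1, df2), axis=1, join="outer", sort=False)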
diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py
index f5b628a..328010f 100644
--- a/src/mudata/_core/mudata.py
+++ b/src/mudata/_core/mudata.py
@@ -452,6 +452,8 @@ def _check_changed_attr_names(self, attr: str, columns: bool = False):
         attr_names_changed, attr_columns_changed = False, False
         if not hasattr(self, attrhash):
             attr_names_changed, attr_columns_changed = True, True
+        elif len(self.mod) < len(getattr(self, attrhash)):
+            attr_names_changed, attr_columns_changed = True, None
         else:
             for m in self.mod.keys():
                 if m in getattr(self, attrhash):
@@ -599,32 +601,20 @@ def _update_attr(
             self._update_attr_legacy(attr, axis, join_common, **kwargs)
             return

-        prev_index = getattr(self, attr).index
-
         # No _attrhash when upon read
         # No _attrhash in mudata < 0.2.0
         _attrhash = f"_{attr}hash"
         attr_changed = self._check_changed_attr_names(attr)
-        attr_duplicated = self._check_duplicated_attr_names(attr)
-        attr_intersecting = self._check_intersecting_attr_names(attr)
-
-        if attr_duplicated:
-            warnings.warn(
-                f"{attr}_names are not unique. To make them unique, call `.{attr}_names_make_unique`.",
-                stacklevel=2,
-            )
-            if self._axis == -1:
-                warnings.warn(
-                    f"Behaviour is not defined with axis=-1, {attr}_names need to be made unique first.",
-                    stacklevel=2,
-                )
-
         if not any(attr_changed):
             # Nothing to update
             return

         data_global = getattr(self, attr)
+        prev_index = data_global.index
+
+        attr_duplicated = not data_global.index.is_unique or self._check_duplicated_attr_names(attr)
+        attr_intersecting = self._check_intersecting_attr_names(attr)

         # Generate unique colnames
         (rowcol,) = self._find_unique_colnames(attr, 1)
@@ -633,141 +623,109 @@ def _update_attr(
         attrm = getattr(self, attr + "m")
         attrp = getattr(self, attr + "p")
         attrmap = getattr(self, attr + "map")

-        # TODO: take advantage when attr_changed[0] == False — only new columns to be added
+        dfs = [
+            getattr(a, attr)
+            .loc[:, []]
+            .assign(**{f"{m}:{rowcol}": np.arange(getattr(a, attr).shape[0])})
+            for m, a in self.mod.items()
+        ]
+
+        index_order = None
+        can_update = True
+
+        def fix_attrmap_col(data_mod: pd.DataFrame, mod: str, rowcol: str) -> str:
+            colname = mod + ":" + rowcol
+            # use 0 as special value for missing
+            # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write
+            # also, this is compatible with Muon.jl
+            col = data_mod[colname] + 1
+            col.replace(np.nan, 0, inplace=True)
+            data_mod[colname] = col.astype(np.uint32)
+            return colname
+
+        kept_idx: Any
+        new_idx: Any
+
+        def reorder_data_mod():
+            nonlocal kept_idx, new_idx, data_mod
+            # reorder new index to conform to the old index as much as possible
+            kept_idx = data_global.index[data_global.index.isin(data_mod.index)]
+            new_idx = data_mod.index[~data_mod.index.isin(data_global.index)]
+            data_mod = data_mod.loc[kept_idx.append(new_idx), :]
+
+        def calc_attrm_update():
+            nonlocal index_order, can_update
+            index_order = data_global.index.get_indexer(data_mod.index)
+            can_update = (
+                new_idx.shape[0] == 0  # noqa: F821 filtered or reordered
+                or kept_idx.shape[0] == data_global.shape[0]  # noqa: F821 new rows only
+                or data_mod.shape[0]
+                == data_global.shape[
+                    0
+                ]  # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0])
+                or (
+                    axis == self.axis and axis != -1 and data_mod.shape[0] > data_global.shape[0]
+                )  # new modality added and concatenated
+            )

         #
         # Join modality .obs/.var tables
         #
         # Main case: no duplicates and no intersection if the axis is not shared
-        #
         if not attr_duplicated:
             # Shared axis
-            if axis == (1 - self._axis) or self._axis == -1:
-                # We assume attr_intersecting and can't join_common
-                data_mod = pd.concat(
-                    [
-                        getattr(a, attr)
-                        .loc[:, []]
-                        .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])})
-                        .add_prefix(m + ":")
-                        for m, a in self.mod.items()
-                    ],
-                    join="outer",
-                    axis=1,
-                    sort=False,
-                )
-            else:
-                data_mod = _maybe_coerce_to_bool(
-                    pd.concat(
-                        [
-                            getattr(a, attr)
-                            .loc[:, []]
-                            .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])})
-                            .add_prefix(m + ":")
-                            for m, a in self.mod.items()
-                        ],
-                        join="outer",
-                        axis=0,
-                        sort=False,
-                    )
-                )
-
+            data_mod = pd.concat(
+                dfs,
+                join="outer",
+                axis=1 if axis == (1 - self._axis) or self._axis == -1 else 0,
+                sort=False,
+            )
             for mod in self.mod.keys():
-                colname = mod + ":" + rowcol
-                # use 0 as special value for missing
-                # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write
-                # also, this is compatible to Muon.jl
-                col = data_mod[colname] + 1
-                col.replace(np.nan, 0, inplace=True)
-                data_mod[colname] = col.astype(np.uint32)
-
-            if len(data_global.columns) > 0:
-                # TODO: if there were intersecting attrnames between modalities,
-                # this will increase the size of the index
-                # Should we use attrmap to figure the index out?
-                #
-                if not attr_intersecting:
-                    data_mod = data_mod.join(data_global, how="left", sort=False)
-                else:
-                    # In order to preserve the order of the index, instead,
-                    # perform a join based on (index, attrmap value) pairs.
-                    (col_index,) = self._find_unique_colnames(attr, 1)
-                    data_mod = data_mod.rename_axis(col_index, axis=0).reset_index()
+                fix_attrmap_col(data_mod, mod, rowcol)

-                    data_global = data_global.rename_axis(col_index, axis=0).reset_index()
-                    for mod in self.mod.keys():
-                        # Only add mapping for modalities that exist in attrmap
-                        if mod in getattr(self, attr + "map"):
-                            data_global[mod + ":" + rowcol] = getattr(self, attr + "map")[mod]
-                    attrmap_columns = [
-                        mod + ":" + rowcol
-                        for mod in self.mod.keys()
-                        if mod in getattr(self, attr + "map")
-                    ]
+            data_mod = _make_index_unique(data_mod, force=attr_intersecting)
+            data_global = _make_index_unique(data_global, force=attr_intersecting)
+            if data_global.shape[1] > 0:
+                data_mod = data_mod.join(data_global, how="left", sort=False)

-                    data_mod = data_mod.merge(
-                        data_global, on=[col_index, *attrmap_columns], how="left", sort=False
-                    )
+            if data_global.shape[0] > 0:
+                reorder_data_mod()
+                calc_attrm_update()

-                    # Restore the index and remove the helper column
-                    data_mod = data_mod.set_index(col_index).rename_axis(None, axis=0)
-                    data_global = data_global.set_index(col_index).rename_axis(None, axis=0)
+            data_mod = _restore_index(data_mod)
+            data_global = _restore_index(data_global)
         #
         # General case: with duplicates and/or intersections
         #
         else:
-            dfs = [
-                _make_index_unique(
-                    getattr(a, attr)
-                    .loc[:, []]
-                    .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])})
-                    .add_prefix(m + ":")
-                )
-                for m, a in self.mod.items()
-            ]
+            dfs = [_make_index_unique(df, force=True) for df in dfs]

             data_mod = pd.concat(
                 dfs,
                 join="outer",
-                axis=axis,
+                axis=1 if axis == (1 - self._axis) or self._axis == -1 else 0,
                 sort=False,
             )

-            # pd.concat wrecks the ordering when doing an outer join with a MultiIndex and different data frame shapes
-            if axis == 1:
-                newidx = (
-                    reduce(lambda x, y: x.union(y, sort=False), (df.index for df in dfs))
-                    .to_frame()
-                    .reset_index(level=1, drop=True)
-                )
-                globalidx = data_global.index.get_level_values(0)
-                mask = globalidx.isin(newidx.iloc[:, 0])
-                if len(mask) > 0:
-                    negativemask = ~newidx.index.get_level_values(0).isin(globalidx)
-                    newidx = pd.MultiIndex.from_frame(
-                        pd.concat(
-                            [newidx.loc[globalidx[mask], :], newidx.iloc[negativemask, :]], axis=0
-                        )
-                    )
-                data_mod = data_mod.reindex(newidx, copy=False)
-                data_mod = _restore_index(data_mod)
-
             data_mod.index.set_names(rowcol, inplace=True)
             data_global.index.set_names(rowcol, inplace=True)

             for mod, amod in self.mod.items():
-                colname = mod + ":" + rowcol
-                # use 0 as special value for missing
-                # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write
-                # also, this is compatible to Muon.jl
-                col = data_mod.loc[:, colname] + 1
-                col.replace(np.nan, 0, inplace=True)
-                col = col.astype(np.uint32)
-                data_mod.loc[:, colname] = col
-                data_mod.set_index(colname, append=True, inplace=True)
-                if mod in attrmap and np.sum(attrmap[mod] > 0) == getattr(amod, attr).shape[0]:
-                    data_global.set_index(attrmap[mod], append=True, inplace=True)
-                    data_global.index.set_names(colname, level=-1, inplace=True)
-
-            if len(data_global) > 0:
+                colname = fix_attrmap_col(data_mod, mod, rowcol)
+                if mod in attrmap:
+                    modmap = attrmap[mod].ravel()
+                    modmask = modmap > 0
+                    # only use unchanged modalities for ordering
+                    if (
+                        modmask.sum() == getattr(amod, attr).shape[0]
+                        and (
+                            getattr(amod, attr).index[modmap[modmask] - 1] == prev_index[modmask]
+                        ).all()
+                    ):
+                        data_mod.set_index(colname, append=True, inplace=True)
+                        data_global.set_index(attrmap[mod].reshape(-1), append=True, inplace=True)
+                        data_global.index.set_names(colname, level=-1, inplace=True)
+
+            if data_global.shape[0] > 0:
                 if not data_global.index.is_unique:
                     warnings.warn(
                         f"{attr}_names is not unique, global {attr} is present, and {attr}map is empty. The update() is not well-defined, verify if global {attr} map to the correct modality-specific {attr}.",
@@ -776,65 +734,47 @@ def _update_attr(
                 data_mod.reset_index(
                     data_mod.index.names.difference(data_global.index.names), inplace=True
                 )
-                data_mod = _make_index_unique(data_mod)
-                data_global = _make_index_unique(data_global)
+                # after inserting a new modality with duplicates, but no duplicates before:
+                # data_mod.index is not unique
+                # after deleting a modality with duplicates: data_global.index is not unique, but
+                # data_mod.index is unique
+                need_unique = data_mod.index.is_unique | data_global.index.is_unique
+                data_global = _make_index_unique(data_global, force=need_unique)
+                data_mod = _make_index_unique(data_mod, force=need_unique)
                 data_mod = data_mod.join(data_global, how="left", sort=False)
+
+                reorder_data_mod()
+                calc_attrm_update()
+
+                if need_unique:
+                    data_mod = _restore_index(data_mod)
+                    data_global = _restore_index(data_global)
+
                 data_mod.reset_index(level=list(range(1, data_mod.index.nlevels)), inplace=True)
+                data_global.reset_index(level=list(range(1, data_global.index.nlevels)), inplace=True)
                 data_mod.index.set_names(None, inplace=True)
+                data_global.index.set_names(None, inplace=True)

         # get adata positions and remove columns from the data frame
         mdict = {}
         for m in self.mod.keys():
             colname = m + ":" + rowcol
             mdict[m] = data_mod[colname].to_numpy()
-            # data_mod.drop(colname, axis=1, inplace=True)
-
-        # Add data from global .obs/.var columns
         # This might reduce the size of .obs/.var if observations/variables were removed
-        if getattr(self, attr).index.is_unique:
-            # There are no new values in the index
-            # Original index is present in data_global
-            attr_reindexed = getattr(self, attr).reindex(index=data_mod.index, copy=False)
-        else:
-            # Reindexing won't work with duplicated labels:
-            # cannot reindex on an axis with duplicate labels.
-            # Use attrmap to resolve it.
-
-            # TODO: might be possible to refactor to memoize it
-            # if it has already been done in the same ._update_attr()
-            col_index, col_range = self._find_unique_colnames(attr, 2)
-
-            # copy is made here
-            data_mod = data_mod.rename_axis(col_index, axis=0).reset_index()
-
-            data_global[col_range] = np.arange(len(data_global))
-
-            for mod in self.mod.keys():
-                if mod in getattr(self, attr + "map"):
-                    data_global[mod + ":" + rowcol] = getattr(self, attr + "map")[mod]
-            attrmap_columns = [
-                mod + ":" + rowcol for mod in self.mod.keys() if mod in getattr(self, attr + "map")
-            ]
-
-            data_mod = data_mod.merge(data_global, on=attrmap_columns, how="left", sort=False)
-
-            index_selection = data_mod[col_range].values
-
-            data_mod.drop(col_range, axis=1, inplace=True)
-            data_global.drop(col_range, axis=1, inplace=True)
-
-            # Restore the index and remove the helper column
-            data_mod = data_mod.set_index(col_index).rename_axis(None, axis=0)
-            attr_reindexed = getattr(self, attr).iloc[index_selection]
-            attr_reindexed.index = data_mod.index
+            data_mod.drop(colname, axis=1, inplace=True)

-        # Clean up
-        for colname in (mod + "+" + rowcol for mod in self.mod.keys()):
-            data_mod.drop(colname, axis=1, inplace=True, errors="ignore")
+        if not data_mod.index.is_unique:
+            warnings.warn(
+                f"{attr}_names are not unique. To make them unique, call `.{attr}_names_make_unique`."
+            )
+            if self._axis == -1:
+                warnings.warn(
+                    f"Behaviour is not defined with axis=-1, {attr}_names need to be made unique first."
+                )

         setattr(
             self,
             "_" + attr,
-            attr_reindexed,
+            data_mod,
         )

         # Update .obsm/.varm
@@ -844,51 +784,18 @@ def _update_attr(
         for mod, mapping in mdict.items():
             attrm[mod] = mapping > 0

-        now_index = getattr(self, attr).index
-
-        if len(prev_index) == 0:
-            # New object
-            pass
-        elif now_index.equals(prev_index):
-            # Index is the same
-            pass
-        else:
-            keep_index = prev_index.isin(now_index)
-            new_index = ~now_index.isin(prev_index)
-
-            if new_index.sum() == 0 or (
-                keep_index.sum() + new_index.sum() == len(now_index)
-                and len(now_index) > len(prev_index)
-            ):
-                # Another length (filtered) or new modality added
-                # Update .obsm/.varm (size might have changed)
-                # NOTE: .get_index doesn't work with duplicated indices
-                if any(prev_index.duplicated()):
-                    # Assume the relative order of duplicates hasn't changed
-                    # NOTE: .get_loc() for each element is too slow
-                    # We will rename duplicated in prev_index and now_index
-                    # in order to use .get_indexer
-                    # index_order = [
-                    #     prev_index.get_loc(i) if i in prev_index else -1 for i in now_index
-                    # ]
-                    prev_values = prev_index.values.copy()
-                    now_values = now_index.values.copy()
-                    for value in prev_index[np.where(prev_index.duplicated())[0]]:
-                        v_now = np.where(now_index == value)[0]
-                        v_prev = np.where(prev_index.get_loc(value))[0]
-                        for i in range(min(len(v_now), len(v_prev))):
-                            prev_values[v_prev[i]] = f"{str(value)}-{i}"
-                            now_values[v_now[i]] = f"{str(value)}-{i}"
-
-                    prev_index = pd.Index(prev_values)
-                    now_index = pd.Index(now_values)
-
-                index_order = prev_index.get_indexer(now_index)
-
-                for mx_key in attrm.keys():
+        if index_order is not None:
+            if can_update:
+                for mx_key, mx in attrm.items():
                     if mx_key not in self.mod.keys():
                         # not a modality name
-                        attrm[mx_key] = attrm[mx_key][index_order]
-                        attrm[mx_key][index_order == -1] = np.nan
+                        cattr = attrm[mx_key]
+                        if isinstance(cattr, pd.DataFrame):
+                            cattr = cattr.iloc[index_order, :]
+                            cattr.iloc[index_order == -1, :] = pd.NA
+                        else:
+                            cattr = cattr[index_order]
+                            cattr[index_order == -1] = np.nan
+                        attrm[mx_key] = cattr

                 # Update .obsp/.varp (size might have changed)
                 for mx_key in attrp.keys():
@@ -896,11 +803,6 @@ def _update_attr(
                     attrp[mx_key][index_order == -1, :] = -1
                     attrp[mx_key][:, index_order == -1] = -1

-            elif len(now_index) == len(prev_index):
-                # Renamed since new_index.sum() != 0
-                # We have to assume the order hasn't changed
-                pass
-            else:
                 raise NotImplementedError(
                     f"{attr}_names seem to have been renamed and filtered at the same time. "
" @@ -1133,7 +1035,8 @@ def _update_attr_legacy( getattr(a, attr) .drop(columns_common, axis=1) .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") + .add_prefix(m + ":"), + force=True, ) ) for m, a in self.mod.items() @@ -1150,7 +1053,7 @@ def _update_attr_legacy( data_common = pd.concat( [ _maybe_coerce_to_boolean( - _make_index_unique(getattr(a, attr)[columns_common]) + _make_index_unique(getattr(a, attr)[columns_common], force=True) ) for m, a in self.mod.items() ], @@ -1166,7 +1069,8 @@ def _update_attr_legacy( _make_index_unique( getattr(a, attr) .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") + .add_prefix(m + ":"), + force=True, ) for m, a in self.mod.items() ] @@ -1228,8 +1132,8 @@ def _update_attr_legacy( data_mod.reset_index( data_mod.index.names.difference(data_global.index.names), inplace=True ) - data_mod = _make_index_unique(data_mod) - data_global = _make_index_unique(data_global) + data_mod = _make_index_unique(data_mod, force=True) + data_global = _make_index_unique(data_global, force=True) data_mod = data_mod.join(data_global, how="left", sort=False) data_mod.reset_index(level=list(range(1, data_mod.index.nlevels)), inplace=True) data_mod.index.set_names(None, inplace=True) diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index f3daa95..41a9f45 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -8,25 +8,27 @@ T = TypeVar("T", pd.Series, pd.DataFrame) -def _make_index_unique(df: pd.DataFrame) -> pd.DataFrame: +def _make_index_unique(df: pd.DataFrame, force: bool = False) -> pd.DataFrame: + if not force and df.index.is_unique: + return df + dup_idx = np.zeros((df.shape[0],), dtype=np.uint8) - if not df.index.is_unique: - duplicates = np.nonzero(df.index.duplicated())[0] - cnt = Counter() - for dup in duplicates: - idxval = df.index[dup] - newval = cnt[idxval] + 1 - try: - dup_idx[dup] = newval - except OverflowError: - dup_idx = dup_idx.astype(np.min_scalar_type(newval)) - dup_idx[dup] = newval - cnt[idxval] = newval + duplicates = np.nonzero(df.index.duplicated())[0] + cnt = Counter() + for dup in duplicates: + idxval = df.index[dup] + newval = cnt[idxval] + 1 + try: + dup_idx[dup] = newval + except OverflowError: + dup_idx = dup_idx.astype(np.min_scalar_type(newval)) + dup_idx[dup] = newval + cnt[idxval] = newval return df.set_index(dup_idx, append=True) def _restore_index(df: pd.DataFrame) -> pd.DataFrame: - return df.reset_index(level=-1, drop=True) + return df.reset_index(level=-1, drop=True) if df.index.nlevels > 1 else df def _maybe_coerce_to_boolean(df: T) -> T: diff --git a/tests/test_update.py b/tests/test_update.py index ea27d9a..818d47f 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -1,32 +1,32 @@ import unittest +from functools import reduce import numpy as np import pytest from anndata import AnnData -from mudata import MuData +from mudata import MuData, set_options @pytest.fixture() -def mdata(request, obs_n, obs_across, obs_mod): - # Generate unique, intersecting, and joint observations by default +def modalities(request, obs_n, obs_across, obs_mod): + n_mod = 3 + mods = dict() np.random.seed(100) - mod1 = AnnData(X=np.random.normal(size=3000).reshape(-1, 10)) - mod2 = AnnData(X=np.random.normal(size=1000).reshape(-1, 10)) - - mods = {"mod1": mod1, "mod2": mod2} - # Make var_names different in different modalities - for m in ["mod1", "mod2"]: - mods[m].obs_names = [f"obs{i}" for i in range(mods[m].n_obs)] - mods[m].var_names = 
diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py
index f3daa95..41a9f45 100644
--- a/src/mudata/_core/utils.py
+++ b/src/mudata/_core/utils.py
@@ -8,25 +8,27 @@
 T = TypeVar("T", pd.Series, pd.DataFrame)


-def _make_index_unique(df: pd.DataFrame) -> pd.DataFrame:
+def _make_index_unique(df: pd.DataFrame, force: bool = False) -> pd.DataFrame:
+    if not force and df.index.is_unique:
+        return df
+
     dup_idx = np.zeros((df.shape[0],), dtype=np.uint8)
-    if not df.index.is_unique:
-        duplicates = np.nonzero(df.index.duplicated())[0]
-        cnt = Counter()
-        for dup in duplicates:
-            idxval = df.index[dup]
-            newval = cnt[idxval] + 1
-            try:
-                dup_idx[dup] = newval
-            except OverflowError:
-                dup_idx = dup_idx.astype(np.min_scalar_type(newval))
-                dup_idx[dup] = newval
-            cnt[idxval] = newval
+    duplicates = np.nonzero(df.index.duplicated())[0]
+    cnt = Counter()
+    for dup in duplicates:
+        idxval = df.index[dup]
+        newval = cnt[idxval] + 1
+        try:
+            dup_idx[dup] = newval
+        except OverflowError:
+            dup_idx = dup_idx.astype(np.min_scalar_type(newval))
+            dup_idx[dup] = newval
+        cnt[idxval] = newval
     return df.set_index(dup_idx, append=True)


 def _restore_index(df: pd.DataFrame) -> pd.DataFrame:
-    return df.reset_index(level=-1, drop=True)
+    return df.reset_index(level=-1, drop=True) if df.index.nlevels > 1 else df


 def _maybe_coerce_to_boolean(df: T) -> T:
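The new force flag lets callers skip the counting pass when the index is already unique, while still requesting the extra index level unconditionally where the join logic depends on it being present. A condensed, self-contained sketch of the round trip (the real _make_index_unique additionally grows the counter dtype on OverflowError, which is elided here):

    from collections import Counter

    import numpy as np
    import pandas as pd

    def make_index_unique(df: pd.DataFrame, force: bool = False) -> pd.DataFrame:
        if not force and df.index.is_unique:
            return df  # fast path added by this patch
        dup_idx = np.zeros((df.shape[0],), dtype=np.uint64)
        cnt = Counter()
        for dup in np.nonzero(df.index.duplicated())[0]:
            cnt[df.index[dup]] += 1
            dup_idx[dup] = cnt[df.index[dup]]
        return df.set_index(dup_idx, append=True)

    def restore_index(df: pd.DataFrame) -> pd.DataFrame:
        # no-op on a flat index, mirroring the nlevels > 1 guard added above
        return df.reset_index(level=-1, drop=True) if df.index.nlevels > 1 else df

    df = pd.DataFrame({"x": [1, 2, 3]}, index=["a", "a", "b"])
    unique = make_index_unique(df)  # index becomes (a, 0), (a, 1), (b, 0)
    assert restore_index(unique).index.equals(df.index)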
f"{attr}_names") + return np.concatenate( + [np.atleast_1d(attrm[key][np.nonzero(index == name)[0]]) for name in names] + ) + + def test_update_simple(self, mdata, axis): + """ + Update should work when + - obs_names are the same across modalities, + - var_names are unique to each modality + """ + attr = "obs" if axis == 0 else "var" + oattr = "var" if axis == 0 else "obs" + + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" + + # names along non-axis are concatenated + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) + assert ( + getattr(mdata, f"{oattr}_names") + == reduce( + lambda x, y: x.append(y), + (getattr(mod, f"{oattr}_names") for mod in mdata.mod.values()), ) - mods["mod2"] = mods["mod2"][mod2_which_obs].copy() + ).all() - if obs_across: - if obs_across != "intersecting": - raise NotImplementedError("Tests for non-intersecting obs_names are not implemented") + # names along axis are unioned + axisnames = reduce( + lambda x, y: x.union(y, sort=False), + (getattr(mod, f"{attr}_names") for mod in mdata.mod.values()), + ) + assert mdata.shape[axis] == axisnames.shape[0] + assert (getattr(mdata, f"{attr}_names").sort_values() == axisnames.sort_values()).all() + + # guards against Pandas scrambling the order. This was the case for pandas < 1.4.0 when using pd.concat with an outer join on a MultiIndex. + # reprex: + # + # import numpy as np + # import pandas as pd + # df1 = pd.DataFrame({"a": np.repeat(np.arange(5), 2), "b": np.tile(np.asarray([0, 1]), 5), "c": np.arange(10)}).set_index("a").set_index("b", append=True) + # df2 = pd.DataFrame({"a": np.repeat(np.arange(10), 2), "b": np.tile(np.asarray([0, 1]), 10), "d": np.arange(20)}).set_index("a").set_index("b", append=True) + # df1 = df1.iloc[::-1, :] + # df = pd.concat((kdf1, df2), axis=1, join="outer", sort=False) + assert ( + getattr(mdata, f"{attr}_names")[: mdata["mod1"].shape[axis]] + == getattr(mdata["mod1"], f"{attr}_names") + ).all() + + def test_update_add_modality(self, modalities, axis): + modnames = list(modalities.keys()) + mdata = MuData({modname: modalities[modname] for modname in modnames[:-2]}, axis=axis) + + attr = "obs" if axis == 0 else "var" + oattr = "var" if axis == 0 else "obs" + + for i in (-2, -1): + old_attrnames = getattr(mdata, f"{attr}_names") + old_oattrnames = getattr(mdata, f"{oattr}_names") + + some_obs_names = mdata.obs_names[:2] + mdata.obsm["test"] = np.random.normal(size=(mdata.n_obs, 1)) + true_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) + + mdata.mod[modnames[i]] = modalities[modnames[i]] + mdata.update() + + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" + + test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) + if axis == 1: + assert np.isnan(mdata.obsm["test"]).sum() == modalities[modnames[i]].n_obs + assert np.all(np.isnan(mdata.obsm["test"][-modalities[modnames[i]].n_obs :])) + assert np.all(~np.isnan(mdata.obsm["test"][: -modalities[modnames[i]].n_obs])) + assert ( + test_obsm_values[~np.isnan(test_obsm_values)].reshape(-1) + == true_obsm_values.reshape(-1) + ).all() + else: + assert (test_obsm_values == true_obsm_values).all() + + attrnames = getattr(mdata, f"{attr}_names") + oattrnames = getattr(mdata, f"{oattr}_names") + assert (attrnames[: old_attrnames.size] == old_attrnames).all() + assert (oattrnames[: old_oattrnames.size] == old_oattrnames).all() + 
+
+    def test_update_add_modality(self, modalities, axis):
+        modnames = list(modalities.keys())
+        mdata = MuData({modname: modalities[modname] for modname in modnames[:-2]}, axis=axis)
+
+        attr = "obs" if axis == 0 else "var"
+        oattr = "var" if axis == 0 else "obs"
+
+        for i in (-2, -1):
+            old_attrnames = getattr(mdata, f"{attr}_names")
+            old_oattrnames = getattr(mdata, f"{oattr}_names")
+
+            some_obs_names = mdata.obs_names[:2]
+            mdata.obsm["test"] = np.random.normal(size=(mdata.n_obs, 1))
+            true_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
+
+            mdata.mod[modnames[i]] = modalities[modnames[i]]
+            mdata.update()
+
+            for mod in mdata.mod.keys():
+                assert mdata.obsmap[mod].dtype.kind == "u"
+                assert mdata.varmap[mod].dtype.kind == "u"
+
+            test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
+            if axis == 1:
+                assert np.isnan(mdata.obsm["test"]).sum() == modalities[modnames[i]].n_obs
+                assert np.all(np.isnan(mdata.obsm["test"][-modalities[modnames[i]].n_obs :]))
+                assert np.all(~np.isnan(mdata.obsm["test"][: -modalities[modnames[i]].n_obs]))
+                assert (
+                    test_obsm_values[~np.isnan(test_obsm_values)].reshape(-1)
+                    == true_obsm_values.reshape(-1)
+                ).all()
+            else:
+                assert (test_obsm_values == true_obsm_values).all()
+
+            attrnames = getattr(mdata, f"{attr}_names")
+            oattrnames = getattr(mdata, f"{oattr}_names")
+            assert (attrnames[: old_attrnames.size] == old_attrnames).all()
+            assert (oattrnames[: old_oattrnames.size] == old_oattrnames).all()
+
+            assert (
+                attrnames
+                == old_attrnames.union(
+                    getattr(modalities[modnames[i]], f"{attr}_names"), sort=False
+                )
+            ).all()
+            assert (
+                oattrnames
+                == old_oattrnames.append(getattr(modalities[modnames[i]], f"{oattr}_names"))
+            ).all()
+
+    def test_update_delete_modality(self, mdata, axis):
+        modnames = list(mdata.mod.keys())
+        attr = "obs" if axis == 0 else "var"
+        oattr = "var" if axis == 0 else "obs"
+        attrm = f"{attr}m"
+        oattrm = f"{oattr}m"
+
+        fullbatch = getattr(mdata, attr)["batch"]
+        fullobatch = getattr(mdata, oattr)["batch"]
+        fulltestm = getattr(mdata, attrm)["test"]
+        fullotestm = getattr(mdata, oattrm)["test"]
+        keptmask = (getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0) | (
+            getattr(mdata, f"{attr}map")[modnames[2]].reshape(-1) > 0
+        )
+        keptomask = (getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0) | (
+            getattr(mdata, f"{oattr}map")[modnames[2]].reshape(-1) > 0
+        )

-    if obs_mod:
-        if obs_mod == "duplicated":
-            for m in ["mod1", "mod2"]:
-                # Index does not support mutable operations
-                obs_names = mods[m].obs_names.values.copy()
-                obs_names[1] = obs_names[0]
-                mods[m].obs_names = obs_names
-        elif obs_mod == "extreme_duplicated":
-            obs_names = mods["mod1"].obs_names.to_numpy()
-            obs_names[:-1] = obs_names[0]
-            mods["mod1"].obs_names = obs_names
-
-    return mods
+        del mdata.mod[modnames[0]]
+        mdata.update()

+        for mod in mdata.mod.keys():
+            assert mdata.obsmap[mod].dtype.kind == "u"
+            assert mdata.varmap[mod].dtype.kind == "u"
+
+        assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())
+        assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all()
+        assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all()
+        assert (getattr(mdata, attrm)["test"] == fulltestm[keptmask, :]).all()
+        assert (getattr(mdata, oattrm)["test"] == fullotestm[keptomask, :]).all()
+
+        fullbatch = getattr(mdata, attr)["batch"]
+        fullobatch = getattr(mdata, oattr)["batch"]
+        fulltestm = getattr(mdata, attrm)["test"]
+        fullotestm = getattr(mdata, oattrm)["test"]
+        keptmask = getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0
+        keptomask = getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0
+
+        del mdata.mod[modnames[2]]
+        mdata.update()
+
+        assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())
+        assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all()
+        assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all()
+        assert (getattr(mdata, attrm)["test"] == fulltestm[keptmask, :]).all()
+        assert (getattr(mdata, oattrm)["test"] == fullotestm[keptomask, :]).all()
+
+    def test_update_intersecting(self, modalities, axis):
+        """
+        Update should work when
+        - obs_names are the same across modalities,
+        - there are intersecting var_names,
+          which are unique in each modality
+        """
+        attr = "obs" if axis == 0 else "var"
+        oattr = "var" if axis == 0 else "obs"
+        for m, mod in modalities.items():
+            setattr(
+                mod,
+                f"{oattr}_names",
+                [
+                    f"{m}_{oattr}{j}" if j != 0 else f"{oattr}_{j}"
+                    for j in range(mod.shape[1 - axis])
+                ],
+            )
+
+        mdata = MuData(modalities, axis=axis)
+
+        for mod in mdata.mod.keys():
+            assert mdata.obsmap[mod].dtype.kind == "u"
+            assert mdata.varmap[mod].dtype.kind == "u"
+
+        # names along non-axis are concatenated
+        assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in modalities.values())
+        assert (
+            getattr(mdata, f"{oattr}_names")
+            == reduce(
+                lambda x, y: x.append(y),
+                (getattr(mod, f"{oattr}_names") for mod in modalities.values()),
+            )
+        ).all()
+
+        # names along axis are unioned
+        axisnames = reduce(
+            lambda x, y: x.union(y, sort=False),
+            (getattr(mod, f"{attr}_names") for mod in modalities.values()),
+        )
+        assert mdata.shape[axis] == axisnames.shape[0]
+        assert (getattr(mdata, f"{attr}_names") == axisnames).all()
+
+    def test_update_after_filter_obs_adata(self, mdata, axis):
+        """
+        Check for muon issue #44.
+        """
+        # Replicate in-place filtering in muon:
+        # mu.pp.filter_obs(mdata['mod3'], 'min_count', lambda x: (x < -2))
+
+        old_obsnames = mdata.obs_names
+        old_varnames = mdata.var_names
+
+        filtermask = mdata["mod3"].obs["min_count"] < -2
+        fullfiltermask = mdata.obsmap["mod3"].copy() > 0
+        fullfiltermask[fullfiltermask] = filtermask
+        keptmask = (mdata.obsmap["mod1"] > 0) | (mdata.obsmap["mod2"] > 0) | fullfiltermask
+
+        some_obs_names = mdata[keptmask, :].obs_names.values[:2]
+        true_obsm_values = self.get_attrm_values(mdata[keptmask], "obs", "test", some_obs_names)
+
+        mdata.mod["mod3"] = mdata["mod3"][mdata["mod3"].obs["min_count"] < -2].copy()
+        mdata.update()
+
+        for mod in mdata.mod.keys():
+            assert mdata.obsmap[mod].dtype.kind == "u"
+            assert mdata.varmap[mod].dtype.kind == "u"
+
+        assert mdata.obs["batch"].isna().sum() == 0
+
+        assert (mdata.var_names == old_varnames).all()
+        if axis == 0:
+            # check if the order is preserved
+            assert (mdata.obs_names == old_obsnames[old_obsnames.isin(mdata.obs_names)]).all()
+
+        test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
+        assert (true_obsm_values == test_obsm_values).all()
+
+    def test_update_after_obs_reordered(self, mdata):
+        """
+        Update should work if obs are reordered.
+        """
+        some_obs_names = mdata.obs_names.values[:2]
+
+        true_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
+
+        mdata.mod["mod1"] = mdata["mod1"][::-1].copy()
+        mdata.update()
+
+        for mod in mdata.mod.keys():
+            assert mdata.obsmap[mod].dtype.kind == "u"
+            assert mdata.varmap[mod].dtype.kind == "u"
+
+        test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
+
+        assert (true_obsm_values == test_obsm_values).all()
""" # Replicate in-place filtering in muon: # mu.pp.filter_obs(mdata['mod1'], 'min_count', lambda x: (x < -2)) - mdata.mod["mod1"] = mdata["mod1"][mdata["mod1"].obs["min_count"] < -2].copy() - mdata.update() - assert mdata.obs["batch"].isna().sum() == 0 + mdata_legacy.mod["mod1"] = mdata_legacy["mod1"][ + mdata_legacy["mod1"].obs["min_count"] < -2 + ].copy() + old_obsnames = mdata_legacy.obs_names + mdata_legacy.update() + assert mdata_legacy.obs["batch"].isna().sum() == 0 @pytest.mark.parametrize("obs_mod", ["unique", "extreme_duplicated"]) @pytest.mark.parametrize("obs_across", ["intersecting"]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) - def test_update_after_obs_reordered(self, mdata): + def test_update_after_obs_reordered(self, mdata_legacy): """ Update should work if obs are reordered. """ - mdata.obsm["test_obsm"] = np.random.normal(size=(mdata.n_obs, 2)) + mdata_legacy.obsm["test_obsm"] = np.random.normal(size=(mdata_legacy.n_obs, 2)) - some_obs_names = mdata.obs_names.values[:2] + some_obs_names = mdata_legacy.obs_names.values[:2] true_obsm_values = [ - mdata.obsm["test_obsm"][np.where(mdata.obs_names.values == name)[0][0]] + mdata_legacy.obsm["test_obsm"][np.where(mdata_legacy.obs_names.values == name)[0][0]] for name in some_obs_names ] - mdata.mod["mod1"] = mdata["mod1"][::-1].copy() - mdata.update() + mdata_legacy.mod["mod1"] = mdata_legacy["mod1"][::-1].copy() + mdata_legacy.update() test_obsm_values = [ - mdata.obsm["test_obsm"][np.where(mdata.obs_names == name)[0][0]] + mdata_legacy.obsm["test_obsm"][np.where(mdata_legacy.obs_names == name)[0][0]] for name in some_obs_names ]