diff --git a/.gitignore b/.gitignore index b5fa6f0..ef8e533 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ archive/ .vscode/ .ruff_cache/ .benchmarks/ +scripts/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/Cargo.lock b/Cargo.lock index a5ebd97..c8a02bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,18 +1,18 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "anyhow" -version = "1.0.79" +version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" +checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" [[package]] name = "autocfg" -version = "1.1.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "bitflags" @@ -22,40 +22,34 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" [[package]] name = "crossbeam-deque" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fca89a0e215bab21874660c67903c5f143333cab1da83d041c7ded6053774751" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ - "cfg-if", "crossbeam-epoch", "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" -version = "0.9.17" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e3681d554572a651dda4186cd47240627c3d0114d45a95f6ad27f2f22e7548d" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" dependencies = [ - "autocfg", - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.18" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c" -dependencies = [ - "cfg-if", -] +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "darling" @@ -94,18 +88,18 @@ dependencies = [ [[package]] name = "derive_builder" -version = "0.13.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "660047478bc508c0fde22c868991eec0c40a63e48d610befef466d48e2bee574" +checksum = "8f59169f400d8087f238c5c0c7db6a28af18681717f3b623227d92f397e938c7" dependencies = [ "derive_builder_macro", ] [[package]] name = "derive_builder_core" -version = "0.13.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b217e6dd1011a54d12f3b920a411b5abd44b1716ecfe94f5f2f2f7b52e08ab7" +checksum = "a4ec317cc3e7ef0928b0ca6e4a634a4d6c001672ae210438cf114a83e56b018d" dependencies = [ "darling", "proc-macro2", @@ -115,9 +109,9 @@ dependencies = [ [[package]] name = "derive_builder_macro" -version = "0.13.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5f77d7e20ac9153428f7ca14a88aba652adfc7a0ef0a06d654386310ef663b" +checksum = "870368c3fb35b8031abb378861d4460f573b92238ec2152c927a21f77e3e0127" dependencies = [ "derive_builder_core", "syn 1.0.109", @@ -125,9 +119,9 @@ dependencies = [ [[package]] name = "either" -version = "1.9.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "fnv" @@ -293,9 +287,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.72" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a293318316cf6478ec1ad2a21c49390a8d5b5eae9fab736467d93fbc0edc29c5" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] @@ -346,7 +340,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.43", + "syn 2.0.106", ] [[package]] @@ -358,14 +352,14 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.43", + "syn 2.0.106", ] [[package]] name = "quote" -version = "1.0.34" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22a37c9326af5ed140c86a46655b5278de879853be5573c01df185b6f49a580a" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] @@ -408,9 +402,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.8.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" dependencies = [ "either", "rayon-core", @@ -418,9 +412,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -487,9 +481,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.43" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -504,29 +498,29 @@ checksum = "14c39fd04924ca3a864207c66fc2cd7d22d7c016007f9ce846cbb9326331930a" [[package]] name = "thiserror" -version = "1.0.53" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2cd5904763bad08ad5513ddbb12cf2ae273ca53fa9f68e843e236ec6dfccc09" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.53" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dcf4a824cce0aeacd6f38ae6f24234c8e80d68632338ebaa1443b5df9e29e19" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.43", + "syn 2.0.106", ] [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unindent" @@ -599,6 +593,6 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "xxhash-rust" -version = "0.8.8" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53be06678ed9e83edb1745eb72efc0bbcd7b5c3c35711a860906aed827a13d61" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" diff --git a/Cargo.toml b/Cargo.toml index 0cd331e..817ae88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,14 +10,14 @@ name = "seqpro" crate-type = ["cdylib", "rlib"] [dependencies] -anyhow = "1.0.79" -derive_builder = "0.13.0" +anyhow = "1.0.99" +derive_builder = "0.13.1" ndarray = { version = "0.15.6", features = ["rayon"] } numpy = "0.20.0" rand = { version = "0.8.5", features = ["small_rng"] } -rayon = "1.8.0" -thiserror = "1.0.53" -xxhash-rust = { version = "0.8.8", features = ["xxh3"] } +rayon = "1.11.0" +thiserror = "1.0.69" +xxhash-rust = { version = "0.8.15", features = ["xxh3"] } [dependencies.pyo3] version = "0.20" diff --git a/python/seqpro/_modifiers.py b/python/seqpro/_modifiers.py index a59e39f..c2fc7f1 100644 --- a/python/seqpro/_modifiers.py +++ b/python/seqpro/_modifiers.py @@ -39,11 +39,11 @@ def reverse_complement( def k_shuffle( seqs: SeqType, k: int, + alphabet: NucleotideAlphabet, *, length_axis: int | None = None, ohe_axis: int | None = None, seed: int | np.random.Generator | None = None, - alphabet: NucleotideAlphabet | None = None, ) -> NDArray[Union[np.bytes_, np.uint8]]: """Shuffle sequences while preserving k-let frequencies. @@ -52,6 +52,8 @@ def k_shuffle( seqs : SeqType k : int Size of k-lets to preserve frequencies of. + alphabet : NucleotideAlphabet + Alphabet, needed for OHE sequence input. length_axis : Optional[int], optional Needed for array input. Axis that corresponds to the length of sequences. ohe_axes : Optional[int], optional @@ -59,14 +61,12 @@ def k_shuffle( the same size as the length of the alphabet. seed : int, np.random.Generator, optional Seed or generator for shuffling. - alphabet : Optional[NucleotideAlphabet], optional - Alphabet, needed for OHE sequence input. """ check_axes(seqs, length_axis, ohe_axis) if isinstance(seed, np.random.Generator): - seed = seed.integers(0, np.iinfo(np.int32).max) + seed = seed.integers(0, np.iinfo(np.int32).max) # type: ignore seqs = cast_seqs(seqs) @@ -78,16 +78,13 @@ def k_shuffle( assert ohe_axis is not None seqs = cast(NDArray[np.uint8], seqs) ohe = True - if alphabet is None: - raise ValueError("Need an alphabet to process OHE sequences.") seqs = alphabet.decode_ohe(seqs, ohe_axis=ohe_axis) else: ohe = False seqs = np.moveaxis(seqs, length_axis, -1) # length must be final - seqs = np.ascontiguousarray(seqs) # must be contiguous - shuffled = _k_shuffle(seqs.view("u1"), k, seed).view("S1") + shuffled = _k_shuffle(seqs.view("u1"), k, len(alphabet), seed).view("S1") shuffled = np.moveaxis(shuffled, -1, length_axis) # put length back where it was diff --git a/python/seqpro/_utils.py b/python/seqpro/_utils.py index 61d6e6b..15cbb0e 100644 --- a/python/seqpro/_utils.py +++ b/python/seqpro/_utils.py @@ -1,9 +1,11 @@ -from typing import List, Optional, TypeVar, Union, cast, overload +from __future__ import annotations + +from typing import Optional, TypeVar, Union, cast, overload import numpy as np from numpy.typing import NDArray -NestedStr = Union[bytes, str, List["NestedStr"]] +NestedStr = Union[bytes, str, list["NestedStr"]] """String or nested list of strings""" StrSeqType = Union[NestedStr, NDArray[Union[np.str_, np.object_, np.bytes_]]] diff --git a/python/seqpro/bed.py b/python/seqpro/bed.py index 14cb89f..c7138dc 100644 --- a/python/seqpro/bed.py +++ b/python/seqpro/bed.py @@ -33,15 +33,13 @@ def sort(bed: pl.DataFrame): """Sort a BED-like DataFrame by chromosome, start, and end position, using the natural order of chromosome names e.g. 1, 2, ..., 10, ...""" - contigs = bed["chrom"].unique() - with pl.StringCache(): - pl.Series(natsorted(contigs), dtype=pl.Categorical) - bed = bed.sort( - pl.col("chrom").cast(pl.Categorical), - "chromStart", - "chromEnd", - maintain_order=True, - ) + order = natsorted(bed["chrom"].unique()) + bed = bed.sort( + pl.col("chrom").cast(pl.Enum(order)), + "chromStart", + "chromEnd", + maintain_order=True, + ) return bed @@ -100,7 +98,7 @@ def to_pyr(bedlike: pl.DataFrame) -> pr.PyRanges: "strand": "Strand", }, strict=False, - ).to_pandas(use_pyarrow_extension_array=True) + ).to_pandas() ) diff --git a/python/seqpro/rag/_array.py b/python/seqpro/rag/_array.py index 5c63eec..93fe0c1 100644 --- a/python/seqpro/rag/_array.py +++ b/python/seqpro/rag/_array.py @@ -62,16 +62,13 @@ def __init__( content = _parts_to_content(data) else: content = _with_ragged(data, highlevel=False) - super().__init__(content, behavior=deepcopy(behavior)) + super().__init__(content, behavior=deepcopy(ak.behavior)) self._parts = unbox(self) type_parts: list[str] = [] - name = self.parts.data.dtype.name - if name == "bytes8": - name = "bytes" type_parts.append("var") type_parts.extend([str(s) for s in self.shape[self.rag_dim + 1 :]]) - type_parts.append(f"rag[{name}]") - self.behavior["__typestr__", "ragged"] = " * ".join(type_parts) # type: ignore + type_parts.append(Ragged.__name__) + self.behavior["__typestr__", Ragged.__name__] = " * ".join(type_parts) # type: ignore @staticmethod def from_offsets( @@ -312,10 +309,6 @@ def apply( return Ragged(parts) -behavior = deepcopy(ak.behavior) -behavior["*", "ragged"] = Ragged - - def apply_ufunc( ufunc: np.ufunc, method: str, args: tuple[Any, ...], kwargs: dict[str, Any] ): @@ -323,7 +316,8 @@ def apply_ufunc( return Ragged(getattr(ufunc, method)(*args, **kwargs)) -ak.behavior[np.ufunc, "ragged"] = apply_ufunc +ak.behavior["*", Ragged.__name__] = Ragged +ak.behavior[np.ufunc, Ragged.__name__] = apply_ufunc def _n_var(arr: ak.Array) -> int: @@ -343,12 +337,14 @@ def _with_ragged( @overload def _with_ragged(arr: ak.Array | Content, highlevel: Literal[False]) -> Content: ... def _with_ragged(arr: ak.Array | Content, highlevel: bool = True) -> ak.Array | Content: - def fn(layout, **kwargs): + def fn(layout: Content, **kwargs): if isinstance(layout, (ListArray, ListOffsetArray)): - return ak.with_parameter(layout, "__list__", "ragged", highlevel=False) + return ak.with_parameter( + layout, "__list__", Ragged.__name__, highlevel=False + ) else: - for k in layout.parameters: - del layout.parameters[k] + if layout._parameters is not None: + layout._parameters = None return ak.transform(fn, arr, highlevel=highlevel) # type: ignore @@ -413,6 +409,7 @@ def unbox( node = cast(Content, arr.layout) shape: list[int | None] = [len(node)] + n_ragged = 0 offsets = None while isinstance(node, (ListArray, ListOffsetArray, RegularArray)): @@ -420,6 +417,7 @@ def unbox( shape.append(node.size) else: shape.append(None) + n_ragged += 1 if isinstance(node, ListOffsetArray): offsets = node.offsets.data else: @@ -430,11 +428,14 @@ def unbox( node = node.content + if n_ragged != 1: + raise ValueError(f"Expected 1 ragged dimension, got {n_ragged}") + if isinstance(node, EmptyArray): node = node.to_NumpyArray(dtype=np.float64) if isinstance(node, NumpyArray): - data = cast(NDArray, node.data) + data = cast(NDArray, node.data) # type: ignore if node.parameter("__array__") == "byte": # view uint8 as bytes @@ -474,7 +475,9 @@ def _parts_to_content(parts: RagParts[DTYPE]) -> Content: layout = ListArray( Index(parts.offsets[0, :]), Index(parts.offsets[1, :]), layout ) - layout = ak.with_parameter(layout, "__list__", "ragged", highlevel=False) + layout = ak.with_parameter( + layout, "__list__", Ragged.__name__, highlevel=False + ) else: layout = RegularArray(layout, size) diff --git a/src/kshuffle.rs b/src/kshuffle.rs index 1c6e4c6..f906faf 100644 --- a/src/kshuffle.rs +++ b/src/kshuffle.rs @@ -30,7 +30,9 @@ struct Vertex { } struct HEntry { + /// Number of unique k-mers that come before this k-mer i_vertices: usize, + /// First index where k-mer appears i_sequence: usize, } @@ -38,6 +40,7 @@ pub fn k_shuffle( seqs: ArrayView, k: usize, seed: Option, + alphabet_size: usize, ) -> Array { let mut out = unsafe { Array::uninit(seqs.raw_dim()).assume_init() }; @@ -46,7 +49,7 @@ pub fn k_shuffle( .into_iter() .zip(seqs.rows()) .par_bridge() - .map(|(out_row, row)| k_shuffle1(row, k, seed, out_row)) + .map(|(out_row, row)| k_shuffle1(row, k, seed, out_row, alphabet_size)) .collect::>(); for result in results { @@ -57,17 +60,18 @@ pub fn k_shuffle( } fn k_shuffle1( - arr: ArrayView1, + seq: ArrayView1, k: usize, seed: Option, mut out: ArrayViewMut1, + alphabet_size: usize, ) -> Result<()> { let seed = seed.unwrap_or_else(|| rand::thread_rng().gen()); let mut rng = SmallRng::seed_from_u64(seed); - let l = arr.len(); + let l = seq.len(); if k >= l { - arr.assign_to(out); + seq.assign_to(out); return Ok(()); } @@ -76,36 +80,38 @@ fn k_shuffle1( } if k == 1 { - arr.assign_to(&mut out); + seq.assign_to(&mut out); out.as_slice_mut().unwrap().shuffle(&mut rng); return Ok(()); } let n_lets = l - k + 2; - let mut htable = HashMap::with_capacity_and_hasher(n_lets, Xxh3Builder::new()); + let max_uniq_lets = n_lets.min(alphabet_size.pow((k - 1) as u32)); + let mut htable = HashMap::with_capacity_and_hasher(max_uniq_lets, Xxh3Builder::new()); // find distinct verticess let mut n_vertices = 0; - for (pos, kmer) in arr.windows(k - 1).into_iter().enumerate() { - htable.entry(kmer.to_vec()).or_insert_with(|| { - let hentry = HEntry { - i_vertices: n_vertices, - i_sequence: pos, - }; - n_vertices += 1; - hentry - }); + for (pos, kmer) in seq.windows(k - 1).into_iter().enumerate() { + if n_vertices < max_uniq_lets { + htable.entry(kmer.to_vec()).or_insert_with(|| { + let hentry = HEntry { + i_vertices: n_vertices, + i_sequence: pos, + }; + n_vertices += 1; + hentry + }); + } } - let n_vertices = htable.len(); - let root = arr.slice(s![-(k as isize - 1)..]).to_vec(); + let root = seq.slice(s![-(k as isize - 1)..]).to_vec(); let mut indices = vec![0 as usize; n_lets - 1]; let mut vertices = (0..n_vertices) .map(|_| VertexBuilder::default().intree(false).n_indices(0).next(0)) .collect::>(); // set i_sequence and n_indices for each vertex - for (i, kmer) in arr.windows(k - 1).into_iter().enumerate() { + for (i, kmer) in seq.windows(k - 1).into_iter().enumerate() { let hentry = htable.get(&kmer.to_vec()).unwrap(); let v = &mut vertices[hentry.i_vertices]; @@ -130,11 +136,11 @@ fn k_shuffle1( .collect::>(); // populate indices for each vertex - for (kmer1, kmer2) in arr + for (kmer1, kmer2) in seq .slice(s![..-1]) .windows(k - 1) .into_iter() - .zip(arr.slice(s![1..]).windows(k - 1)) + .zip(seq.slice(s![1..]).windows(k - 1)) { let eu = htable.get(&kmer1.to_vec()).unwrap(); let ev = htable.get(&kmer2.to_vec()).unwrap(); @@ -149,10 +155,20 @@ fn k_shuffle1( // Wilson algorithm for random arborescence let root_idx = htable.get(&root).unwrap().i_vertices; - { - let root_vertex = &mut vertices[root_idx]; - root_vertex.intree = true; - } + wilson_random_spanning_tree(&mut vertices, &indices, root_idx, &mut rng); + random_walk(&mut vertices, &mut indices, root_idx, &mut rng, seq, k, out); + + Ok(()) +} + +fn wilson_random_spanning_tree( + vertices: &mut Vec, + indices: &Vec, + root_idx: usize, + rng: &mut R, +) { + let root_vertex = &mut vertices[root_idx]; + root_vertex.intree = true; for i in 0..vertices.len() { // let mut u = &mut vertices[i]; @@ -195,8 +211,17 @@ fn k_shuffle1( } } } +} - // shuffle indices to prepare for walk +fn random_walk( + vertices: &mut Vec, + indices: &mut Vec, + root_idx: usize, + rng: &mut R, + seq: ArrayView1, + k: usize, + mut out: ArrayViewMut1, +) { let mut j; for (i, u) in vertices.iter_mut().enumerate() { if i != root_idx { @@ -205,16 +230,16 @@ fn k_shuffle1( indices[u.idx_offset + idx] = indices[u.idx_offset + u.next]; let next = u.next; indices[u.idx_offset + next] = j; - indices[u.idx_offset..u.idx_offset + idx].shuffle(&mut rng); + indices[u.idx_offset..u.idx_offset + idx].shuffle(rng); } else { - indices[u.idx_offset..u.idx_offset + u.n_indices].shuffle(&mut rng); + indices[u.idx_offset..u.idx_offset + u.n_indices].shuffle(rng); } u.i_indices = 0; } // walk the graph let out = out.as_slice_mut().unwrap(); - out[..k - 1].clone_from_slice(arr.slice(s![..k - 1]).as_slice().unwrap()); + out[..k - 1].clone_from_slice(seq.slice(s![..k - 1]).as_slice().unwrap()); let mut i = k - 1; let mut u_idx = 0; loop { @@ -229,21 +254,19 @@ fn k_shuffle1( if u_idx != v_idx { let v = &vertices[v_idx]; j = v.i_sequence + k - 2; - out[i] = arr[j]; + out[i] = seq[j]; i += 1; vertices[u_idx].i_indices += 1; } else { let v = &mut vertices[v_idx]; j = v.i_sequence + k - 2; - out[i] = arr[j]; + out[i] = seq[j]; i += 1; v.i_indices += 1; } } u_idx = v_idx; } - - Ok(()) } #[cfg(test)] @@ -264,11 +287,12 @@ mod test { #[test] fn same_freq() { let k = 2; + let alphabet_size = 4; let seq = ArrayView1::from(b"AATAT"); let freqs = kmer_frequencies(seq.as_slice().unwrap(), k); let mut shuffled = unsafe { Array::uninit(seq.len()).assume_init() }; - let res = k_shuffle1(seq.view(), k, Some(1), shuffled.view_mut()); + let res = k_shuffle1(seq.view(), k, Some(1), shuffled.view_mut(), alphabet_size); assert!(res.is_ok()); let shuffled_freqs = kmer_frequencies(shuffled.as_slice().unwrap(), k); diff --git a/src/lib.rs b/src/lib.rs index 25bcda3..9d188e7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ fn seqpro(_py: Python, m: &PyModule) -> PyResult<()> { #[pyfunction] /// Shuffle sequences while preserving k-mer frequencies. -/// +/// /// Parameters /// ---------- /// seqs : NDArray[uint8] @@ -20,15 +20,18 @@ fn seqpro(_py: Python, m: &PyModule) -> PyResult<()> { /// being the sequence length. /// k : int /// Length of k-mers to preserve frequencies of. +/// alphabet_size : int +/// Number of unique characters in the alphabet. /// seed : int, optional /// Seed for the random number generator. fn _k_shuffle<'py>( py: Python<'py>, seqs: PyReadonlyArray<'py, u8, IxDyn>, k: usize, + alphabet_size: usize, seed: Option, ) -> &'py PyArray { let seqs = seqs.as_array(); - let out = kshuffle::k_shuffle(seqs, k, seed); + let out = kshuffle::k_shuffle(seqs, k, seed, alphabet_size); out.into_pyarray(py) } diff --git a/tests/test_modifiers.py b/tests/test_modifiers.py index 37778bb..6e38a0f 100644 --- a/tests/test_modifiers.py +++ b/tests/test_modifiers.py @@ -139,7 +139,7 @@ def test_k_shuffle(): for k in range(2, 5): counts = _count_kmers(seqs, k, length_axis) - shuffled = sp.k_shuffle(seqs, k, length_axis=length_axis, seed=seed) + shuffled = sp.k_shuffle(seqs, k, sp.DNA, length_axis=length_axis, seed=seed) shuffled_counts = _count_kmers(shuffled, k, length_axis) assert counts == shuffled_counts