diff --git a/python/seqpro/_numba.py b/python/seqpro/_numba.py index ba724d1..b4919de 100644 --- a/python/seqpro/_numba.py +++ b/python/seqpro/_numba.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Optional, Union, overload import numba as nb @@ -132,3 +134,25 @@ def gufunc_translate( if (seq_kmers == kmer_keys[i]).all(): res[0] = kmer_values[i] # type: ignore break + + +@nb.guvectorize( + ["(u1, u1[:], u1[:])"], + "(),(n)->()", + nopython=True, + cache=True, +) +def gufunc_complement_bytes( + seq: NDArray[np.uint8], + complement_map: NDArray[np.uint8], + res: NDArray[np.uint8] | None = None, +) -> NDArray[np.uint8]: # type: ignore + res[0] = complement_map[seq] # type: ignore + + +_COMP = np.frombuffer(bytes.maketrans(b"ACGT", b"TGCA"), np.uint8) + + +@nb.vectorize(["u1(u1)"], nopython=True, cache=True) +def ufunc_comp_dna(seq: NDArray[np.uint8]) -> NDArray[np.uint8]: + return _COMP[seq] diff --git a/python/seqpro/_utils.py b/python/seqpro/_utils.py index 15cbb0e..502e902 100644 --- a/python/seqpro/_utils.py +++ b/python/seqpro/_utils.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Optional, TypeVar, Union, cast, overload +from typing import TypeVar, Union, cast, overload import numpy as np from numpy.typing import NDArray +from typing_extensions import TypeGuard NestedStr = Union[bytes, str, list["NestedStr"]] """String or nested list of strings""" @@ -13,20 +14,22 @@ SeqType = Union[NestedStr, NDArray[Union[np.str_, np.object_, np.bytes_, np.uint8]]] +DTYPE = TypeVar("DTYPE", bound=np.generic) -@overload -def cast_seqs(seqs: NDArray[np.uint8]) -> NDArray[np.uint8]: ... +def is_dtype( + obj: object, dtype: DTYPE | np.dtype[DTYPE] | type[DTYPE] +) -> TypeGuard[NDArray[DTYPE]]: + return isinstance(obj, np.ndarray) and np.issubdtype(obj.dtype, dtype) + +@overload +def cast_seqs(seqs: NDArray[np.uint8]) -> NDArray[np.uint8]: ... @overload def cast_seqs(seqs: StrSeqType) -> NDArray[np.bytes_]: ... 
- - @overload -def cast_seqs(seqs: SeqType) -> NDArray[Union[np.bytes_, np.uint8]]: ... - - -def cast_seqs(seqs: SeqType) -> NDArray[Union[np.bytes_, np.uint8]]: +def cast_seqs(seqs: SeqType) -> NDArray[np.bytes_ | np.uint8]: ... +def cast_seqs(seqs: SeqType) -> NDArray[np.bytes_ | np.uint8]: """Cast any sequence type to be a NumPy array of ASCII characters (or left alone as 8-bit unsigned integers if the input is OHE). @@ -54,8 +57,8 @@ def check_axes( seqs: SeqType, - length_axis: Optional[Union[int, bool]] = None, - ohe_axis: Optional[Union[int, bool]] = None, + length_axis: int | bool | None = None, + ohe_axis: int | bool | None = None, ): """Raise errors if length_axis or ohe_axis is missing when they're needed. Pass False to corresponding axis to not check for it. @@ -63,16 +66,16 @@ - ndarray with itemsize == 1 => length axis required. - OHE array => length and OHE axis required. """ + # OHE + if ohe_axis is None and is_dtype(seqs, np.uint8): + raise ValueError("Need an one hot encoding axis to process OHE sequences.") + # bytes or OHE - if length_axis is None and isinstance(seqs, np.ndarray) and seqs.itemsize == 1: + if length_axis is None and is_dtype(seqs, np.generic) and seqs.itemsize == 1: raise ValueError( "Need a length axis to process an ndarray with itemsize == 1 (S1, u1)." 
) - # OHE - if ohe_axis is None and isinstance(seqs, np.ndarray) and seqs.dtype == np.uint8: - raise ValueError("Need an one hot encoding axis to process OHE sequences.") - # length_axis != ohe_axis if ( isinstance(length_axis, int) @@ -82,9 +85,6 @@ def check_axes( raise ValueError("Length and OHE axis must be different.") -DTYPE = TypeVar("DTYPE", bound=np.generic) - - def array_slice(a: NDArray[DTYPE], axis: int, slice_: slice) -> NDArray[DTYPE]: """Slice an array from a dynamic axis.""" return a[(slice(None),) * (axis % a.ndim) + (slice_,)] diff --git a/python/seqpro/alphabets/__init__.py b/python/seqpro/alphabets/__init__.py index 8bd9d79..3a29e93 100644 --- a/python/seqpro/alphabets/__init__.py +++ b/python/seqpro/alphabets/__init__.py @@ -1,4 +1,4 @@ -from ._alphabets import AminoAlphabet, NucleotideAlphabet +from ._alphabets import DNA, AminoAlphabet, NucleotideAlphabet # NOTE the "*" character is termination i.e. STOP codon canonical_codons_to_aas = { @@ -69,7 +69,6 @@ } -DNA = NucleotideAlphabet(alphabet="ACGT", complement="TGCA") RNA = NucleotideAlphabet(alphabet="ACGU", complement="UGCA") AA = AminoAlphabet(*map(list, zip(*canonical_codons_to_aas.items()))) diff --git a/python/seqpro/alphabets/_alphabets.py b/python/seqpro/alphabets/_alphabets.py index 074e8d6..2f7f5a0 100644 --- a/python/seqpro/alphabets/_alphabets.py +++ b/python/seqpro/alphabets/_alphabets.py @@ -1,10 +1,20 @@ +from __future__ import annotations + +from types import MethodType from typing import Dict, List, Optional, Union, cast, overload import numpy as np from numpy.typing import NDArray +from typing_extensions import assert_never -from .._numba import gufunc_ohe, gufunc_ohe_char_idx, gufunc_translate -from .._utils import SeqType, StrSeqType, cast_seqs, check_axes +from .._numba import ( + gufunc_complement_bytes, + gufunc_ohe, + gufunc_ohe_char_idx, + gufunc_translate, + ufunc_comp_dna, +) +from .._utils import SeqType, StrSeqType, cast_seqs, check_axes, is_dtype class 
NucleotideAlphabet: @@ -12,10 +22,11 @@ class NucleotideAlphabet: """Alphabet excluding ambiguous characters e.g. "N" for DNA.""" complement: str array: NDArray[np.bytes_] - complement_map: Dict[str, str] - complement_map_bytes: Dict[bytes, bytes] - str_comp_table: Dict[int, str] + complement_map: dict[str, str] + complement_map_bytes: dict[bytes, bytes] + str_comp_table: dict[int, str] bytes_comp_table: bytes + bytes_comp_array: NDArray[np.bytes_] def __init__(self, alphabet: str, complement: str) -> None: """Parse and validate sequence alphabets. @@ -36,9 +47,7 @@ def __init__(self, alphabet: str, complement: str) -> None: self.array = cast( NDArray[np.bytes_], np.frombuffer(self.alphabet.encode("ascii"), "|S1") ) - self.complement_map: Dict[str, str] = dict( - zip(list(self.alphabet), list(self.complement)) - ) + self.complement_map = dict(zip(list(self.alphabet), list(self.complement))) self.complement_map_bytes = { k.encode("ascii"): v.encode("ascii") for k, v in self.complement_map.items() } @@ -46,6 +55,7 @@ def __init__(self, alphabet: str, complement: str) -> None: self.bytes_comp_table = bytes.maketrans( self.alphabet.encode("ascii"), self.complement.encode("ascii") ) + self.bytes_comp_array = np.frombuffer(self.bytes_comp_table, "S1") def __len__(self): return len(self.alphabet) @@ -109,23 +119,29 @@ def decode_ohe( return _alphabet[idx].reshape(shape) - def complement_bytes(self, byte_arr: NDArray[np.bytes_]) -> NDArray[np.bytes_]: + def complement_bytes( + self, byte_arr: NDArray[np.bytes_], out: NDArray[np.bytes_] | None = None + ) -> NDArray[np.bytes_]: """Get reverse complement of byte (S1) array. Parameters ---------- byte_arr : ndarray[bytes] """ - # * a vectorized implementation using np.unique or np.char.translate is NOT - # * faster even for longer alphabets like IUPAC DNA/RNA. Another optimization to - # * try would be using vectorized bit manipulations. 
- out = byte_arr.copy() - for nuc, comp in self.complement_map_bytes.items(): - out[byte_arr == nuc] = comp - return out + if out is None: + _out = out + else: + _out = out.view(np.uint8) + _out = gufunc_complement_bytes( + byte_arr.view(np.uint8), self.bytes_comp_array.view(np.uint8), _out + ) + return _out.view("S1") def rev_comp_byte( - self, byte_arr: NDArray[np.bytes_], length_axis: int + self, + byte_arr: NDArray[np.bytes_], + length_axis: int, + out: NDArray[np.bytes_] | None = None, ) -> NDArray[np.bytes_]: """Get reverse complement of byte (S1) array. @@ -133,7 +149,7 @@ ---------- byte_arr : ndarray[bytes] """ - out = self.complement_bytes(byte_arr) + out = self.complement_bytes(byte_arr, out) return np.flip(out, length_axis) def rev_comp_string(self, string: str): @@ -150,6 +166,7 @@ def reverse_complement( seqs: StrSeqType, length_axis: Optional[int] = None, ohe_axis: Optional[int] = None, + out: NDArray[np.bytes_] | None = None, ) -> NDArray[np.bytes_]: ... @overload def reverse_complement( @@ -157,6 +174,7 @@ seqs: NDArray[np.uint8], length_axis: Optional[int] = None, ohe_axis: Optional[int] = None, + out: NDArray[np.uint8] | None = None, ) -> NDArray[np.uint8]: ... @overload def reverse_complement( @@ -164,13 +182,15 @@ seqs: SeqType, length_axis: Optional[int] = None, ohe_axis: Optional[int] = None, + out: NDArray[np.bytes_ | np.uint8] | None = None, ) -> NDArray[Union[np.bytes_, np.uint8]]: ... def reverse_complement( self, seqs: SeqType, length_axis: Optional[int] = None, ohe_axis: Optional[int] = None, - ) -> NDArray[Union[np.bytes_, np.uint8]]: + out: NDArray[np.bytes_ | np.uint8] | None = None, + ) -> NDArray[np.bytes_ | np.uint8]: """Reverse complement a sequence. 
Parameters @@ -190,14 +210,20 @@ def reverse_complement( seqs = cast_seqs(seqs) - if seqs.dtype == np.uint8: # OHE + if is_dtype(seqs, np.bytes_): + if length_axis is None: + length_axis = -1 + return self.rev_comp_byte(seqs, length_axis, out) + elif is_dtype(seqs, np.uint8): # OHE assert length_axis is not None assert ohe_axis is not None - return np.flip(seqs, axis=(length_axis, ohe_axis)) + _out = np.flip(seqs, axis=(length_axis, ohe_axis)) + if out is not None: + out[:] = _out + _out = out + return _out else: - if length_axis is None: - length_axis = -1 - return self.rev_comp_byte(seqs, length_axis) # type: ignore + assert_never(seqs) # type: ignore class AminoAlphabet: @@ -334,3 +360,25 @@ def decode_ohe( _alphabet = np.concatenate([self.aa_array, [unknown_char.encode("ascii")]]) return _alphabet[idx].reshape(shape) + + +DNA = NucleotideAlphabet("ACGT", "TGCA") + + +# * Monkey patch DNA instance with a faster complement function using +# * a static, const lookup table. The base method is slower because it uses a +# * dynamic lookup table. 
+def complement_bytes( + self: NucleotideAlphabet, + byte_arr: NDArray[np.bytes_], + out: NDArray[np.bytes_] | None = None, +) -> NDArray[np.bytes_]: + if out is None: + _out = out + else: + _out = out.view(np.uint8) + _out = ufunc_comp_dna(byte_arr.view(np.uint8), _out) # type: ignore + return _out.view("S1") + + +DNA.complement_bytes = MethodType(complement_bytes, DNA) diff --git a/python/seqpro/rag/_array.py b/python/seqpro/rag/_array.py index eea17be..748051c 100644 --- a/python/seqpro/rag/_array.py +++ b/python/seqpro/rag/_array.py @@ -61,7 +61,7 @@ def __init__( if isinstance(data, RagParts): content = _parts_to_content(data) else: - content = _with_ragged(data, highlevel=False) + content = _as_ragged(data, highlevel=False) super().__init__(content, behavior=deepcopy(ak.behavior)) self._parts = unbox(self) type_parts: list[str] = [] @@ -232,7 +232,7 @@ def __getitem__(self, where): if _n_var(arr) == 1: return type(self)(arr) else: - return _without_ragged(arr) + return _as_ak(arr) else: return arr @@ -293,7 +293,7 @@ def reshape(self, *shape: int | None | tuple[int | None, ...]) -> Self: def to_ak(self): """Convert to an Awkward array.""" - arr = _without_ragged(self) + arr = _as_ak(self) arr.behavior = None return arr @@ -331,12 +331,12 @@ def _n_var(arr: ak.Array) -> int: @overload -def _with_ragged( +def _as_ragged( arr: ak.Array | Content, highlevel: Literal[True] = True ) -> ak.Array: ... @overload -def _with_ragged(arr: ak.Array | Content, highlevel: Literal[False]) -> Content: ... -def _with_ragged(arr: ak.Array | Content, highlevel: bool = True) -> ak.Array | Content: +def _as_ragged(arr: ak.Array | Content, highlevel: Literal[False]) -> Content: ... 
+def _as_ragged(arr: ak.Array | Content, highlevel: bool = True) -> ak.Array | Content: def fn(layout: Content, **kwargs): if isinstance(layout, (ListArray, ListOffsetArray)): return ak.with_parameter( @@ -350,16 +350,12 @@ def fn(layout: Content, **kwargs): @overload -def _without_ragged( +def _as_ak( arr: ak.Array | Ragged[DTYPE], highlevel: Literal[True] = True ) -> ak.Array: ... @overload -def _without_ragged( - arr: ak.Array | Ragged[DTYPE], highlevel: Literal[False] -) -> Content: ... -def _without_ragged( - arr: ak.Array | Ragged[DTYPE], highlevel: bool = True -) -> ak.Array | Content: +def _as_ak(arr: ak.Array | Ragged[DTYPE], highlevel: Literal[False]) -> Content: ... +def _as_ak(arr: ak.Array | Ragged[DTYPE], highlevel: bool = True) -> ak.Array | Content: def fn(layout, **kwargs): if isinstance(layout, (ListArray, ListOffsetArray)): return ak.with_parameter(layout, "__list__", None, highlevel=False) diff --git a/tests/test_modifiers.py b/tests/test_modifiers.py index 6e38a0f..f42e6ad 100644 --- a/tests/test_modifiers.py +++ b/tests/test_modifiers.py @@ -2,8 +2,9 @@ import numpy as np import seqpro as sp -from seqpro._modifiers import _align_axes, _slice_kmers -from seqpro._utils import check_axes +from pytest_cases import parametrize_with_cases +from seqpro._modifiers import _align_axes, _slice_kmers, reverse_complement +from seqpro._utils import cast_seqs, check_axes def test_align_axes(): @@ -143,3 +144,79 @@ def test_k_shuffle(): shuffled_counts = _count_kmers(shuffled, k, length_axis) assert counts == shuffled_counts + + +# Test cases for reverse_complement +class ReverseComplementCases: + def case_single_string(self): + """Test single string sequence.""" + seq = "ATCG" + # ATCG -> CGAT (reverse complement) + expected = cast_seqs("CGAT") + return seq, expected, None, None + + def case_list_of_strings(self): + """Test list of string sequences.""" + seqs = ["ATCG", "GCTA"] + # ATCG -> CGAT, GCTA -> TAGC + expected = cast_seqs(["CGAT", "TAGC"]) + return 
seqs, expected, None, None + + def case_byte_array_1d(self): + """Test 1D byte array.""" + seqs = cast_seqs("ATCG") + # ATCG -> CGAT + expected = cast_seqs("CGAT") + return seqs, expected, -1, None + + def case_byte_array_2d(self): + """Test 2D byte array.""" + seqs = cast_seqs(["ATCG", "GCTA"]) + # ATCG -> CGAT, GCTA -> TAGC + expected = cast_seqs(["CGAT", "TAGC"]) + return seqs, expected, -1, None + + def case_byte_array_3d(self): + """Test 3D byte array with last axis as length.""" + seqs = cast_seqs([["AT", "CG"], ["GC", "TA"]]) + # AT -> AT (palindrome), CG -> CG (palindrome) + # GC -> GC (palindrome), TA -> TA (palindrome) + expected = cast_seqs([["AT", "CG"], ["GC", "TA"]]) + return seqs, expected, -1, None + + def case_ohe_array_2d(self): + """Test 2D one-hot encoded array.""" + # Create OHE sequence: "AC" + # Shape: (2, 4) - length axis x alphabet axis + seqs = sp.DNA.ohe("AC") + # Reverse complement: "AC" -> "GT" + expected = sp.DNA.ohe("GT") + return seqs, expected, 0, 1 + + def case_ohe_array_3d(self): + """Test 3D one-hot encoded array.""" + # Create two sequences: "AC" and "GT" + # Shape: (2, 2, 4) - batch x length x alphabet + seqs = sp.DNA.ohe(["AC", "GT"]) + # Reverse complement: + # "AC" -> "GT" + # "GT" -> "AC" + expected = sp.DNA.ohe(["GT", "AC"]) + return seqs, expected, 1, 2 + + def case_palindrome(self): + """Test palindromic sequence (same as its reverse complement).""" + seq = "GAATTC" # EcoRI site - palindrome + expected = cast_seqs("GAATTC") + return seq, expected, None, None + + +@parametrize_with_cases( + "seqs,expected,length_axis,ohe_axis", cases=ReverseComplementCases +) +def test_reverse_complement(seqs, expected, length_axis, ohe_axis): + """Test reverse_complement with various input types and configurations.""" + result = reverse_complement( + seqs, sp.DNA, length_axis=length_axis, ohe_axis=ohe_axis + ) + np.testing.assert_array_equal(result, expected)