24 changes: 24 additions & 0 deletions python/seqpro/_numba.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Optional, Union, overload

import numba as nb
@@ -132,3 +134,25 @@ def gufunc_translate(
if (seq_kmers == kmer_keys[i]).all():
res[0] = kmer_values[i] # type: ignore
break


@nb.guvectorize(
["(u1, u1[:], u1[:])"],
"(),(n)->()",
nopython=True,
cache=True,
)
def gufunc_complement_bytes(
seq: NDArray[np.uint8],
complement_map: NDArray[np.uint8],
res: NDArray[np.uint8] | None = None,
) -> NDArray[np.uint8]: # type: ignore
res[0] = complement_map[seq] # type: ignore


_COMP = np.frombuffer(bytes.maketrans(b"ACGT", b"TGCA"), np.uint8)


@nb.vectorize(["u1(u1)"], nopython=True, cache=True)
def ufunc_comp_dna(seq: NDArray[np.uint8]) -> NDArray[np.uint8]:
return _COMP[seq]
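
A rough usage sketch of the two new ufuncs (assuming the module is importable as seqpro._numba; both operate on the 256-entry translation table produced by bytes.maketrans, indexed by each base's ASCII code):

import numpy as np
from seqpro._numba import gufunc_complement_bytes, ufunc_comp_dna  # assumed import path

seq = np.frombuffer(b"ACGTAC", dtype=np.uint8)                      # ASCII codes of the bases
table = np.frombuffer(bytes.maketrans(b"ACGT", b"TGCA"), np.uint8)  # 256-entry lookup table

comp = gufunc_complement_bytes(seq, table)  # alphabet-specific table passed explicitly
fast = ufunc_comp_dna(seq)                  # DNA-only, table baked in at import time
assert comp.tobytes() == fast.tobytes() == b"TGCATG"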
38 changes: 19 additions & 19 deletions python/seqpro/_utils.py
@@ -1,9 +1,10 @@
from __future__ import annotations

from typing import Optional, TypeVar, Union, cast, overload
from typing import TypeVar, Union, cast, overload

import numpy as np
from numpy.typing import NDArray
from typing_extensions import TypeGuard

NestedStr = Union[bytes, str, list["NestedStr"]]
"""String or nested list of strings"""
@@ -13,20 +14,22 @@

SeqType = Union[NestedStr, NDArray[Union[np.str_, np.object_, np.bytes_, np.uint8]]]

DTYPE = TypeVar("DTYPE", bound=np.generic)

@overload
def cast_seqs(seqs: NDArray[np.uint8]) -> NDArray[np.uint8]: ...

def is_dtype(
obj: object, dtype: DTYPE | np.dtype[DTYPE] | type[DTYPE]
) -> TypeGuard[NDArray[DTYPE]]:
return isinstance(obj, np.ndarray) and np.issubdtype(obj.dtype, dtype)


@overload
def cast_seqs(seqs: NDArray[np.uint8]) -> NDArray[np.uint8]: ...
@overload
def cast_seqs(seqs: StrSeqType) -> NDArray[np.bytes_]: ...


@overload
def cast_seqs(seqs: SeqType) -> NDArray[Union[np.bytes_, np.uint8]]: ...


def cast_seqs(seqs: SeqType) -> NDArray[Union[np.bytes_, np.uint8]]:
def cast_seqs(seqs: SeqType) -> NDArray[np.bytes_ | np.uint8]: ...
def cast_seqs(seqs: SeqType) -> NDArray[np.bytes_ | np.uint8]:
"""Cast any sequence type to be a NumPy array of ASCII characters (or left alone as
8-bit unsigned integers if the input is OHE).

@@ -54,25 +57,25 @@ def cast_seqs(seqs: SeqType) -> NDArray[Union[np.bytes_, np.uint8]]:

def check_axes(
seqs: SeqType,
length_axis: Optional[Union[int, bool]] = None,
ohe_axis: Optional[Union[int, bool]] = None,
length_axis: int | bool | None = None,
ohe_axis: int | bool | None = None,
):
"""Raise errors if length_axis or ohe_axis is missing when they're needed. Pass
False to corresponding axis to not check for it.

- ndarray with itemsize == 1 => length axis required.
- OHE array => length and OHE axis required.
"""
# OHE
if ohe_axis is None and is_dtype(seqs, np.uint8):
raise ValueError("Need an one hot encoding axis to process OHE sequences.")

# bytes or OHE
if length_axis is None and isinstance(seqs, np.ndarray) and seqs.itemsize == 1:
if length_axis is None and is_dtype(seqs, np.bytes_) and seqs.itemsize == 1:
raise ValueError(
"Need a length axis to process an ndarray with itemsize == 1 (S1, u1)."
)

# OHE
if ohe_axis is None and isinstance(seqs, np.ndarray) and seqs.dtype == np.uint8:
raise ValueError("Need an one hot encoding axis to process OHE sequences.")

# length_axis != ohe_axis
if (
isinstance(length_axis, int)
@@ -82,9 +85,6 @@ def check_axes(
raise ValueError("Length and OHE axis must be different.")


DTYPE = TypeVar("DTYPE", bound=np.generic)


def array_slice(a: NDArray[DTYPE], axis: int, slice_: slice) -> NDArray[DTYPE]:
"""Slice an array from a dynamic axis."""
return a[(slice(None),) * (axis % a.ndim) + (slice_,)]
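
is_dtype replaces the scattered isinstance/dtype checks and doubles as a TypeGuard, so type checkers narrow the array type along each branch, while array_slice indexes along a runtime-chosen axis. A minimal sketch (assuming both helpers import from seqpro._utils):

import numpy as np
from seqpro._utils import array_slice, is_dtype  # assumed import path

chars = np.array([b"A", b"C", b"G", b"T"], dtype="S1")
ohe = np.eye(4, dtype=np.uint8)

is_dtype(ohe, np.uint8)      # True  -> narrowed to NDArray[np.uint8]
is_dtype(chars, np.uint8)    # False -> the OHE branch is skipped
is_dtype(chars, np.bytes_)   # True  -> narrowed to NDArray[np.bytes_]
is_dtype("ACGT", np.bytes_)  # False -> plain strings are not ndarrays

array_slice(ohe, axis=-1, slice_=slice(0, 2))  # equivalent to ohe[:, 0:2] for a 2-D array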
3 changes: 1 addition & 2 deletions python/seqpro/alphabets/__init__.py
@@ -1,4 +1,4 @@
from ._alphabets import AminoAlphabet, NucleotideAlphabet
from ._alphabets import DNA, AminoAlphabet, NucleotideAlphabet

# NOTE the "*" character is termination i.e. STOP codon
canonical_codons_to_aas = {
@@ -69,7 +69,6 @@
}


DNA = NucleotideAlphabet(alphabet="ACGT", complement="TGCA")
RNA = NucleotideAlphabet(alphabet="ACGU", complement="UGCA")
AA = AminoAlphabet(*map(list, zip(*canonical_codons_to_aas.items())))

96 changes: 72 additions & 24 deletions python/seqpro/alphabets/_alphabets.py
@@ -1,21 +1,32 @@
from __future__ import annotations

from types import MethodType
from typing import Dict, List, Optional, Union, cast, overload

import numpy as np
from numpy.typing import NDArray
from typing_extensions import assert_never

from .._numba import gufunc_ohe, gufunc_ohe_char_idx, gufunc_translate
from .._utils import SeqType, StrSeqType, cast_seqs, check_axes
from .._numba import (
gufunc_complement_bytes,
gufunc_ohe,
gufunc_ohe_char_idx,
gufunc_translate,
ufunc_comp_dna,
)
from .._utils import SeqType, StrSeqType, cast_seqs, check_axes, is_dtype


class NucleotideAlphabet:
alphabet: str
"""Alphabet excluding ambiguous characters e.g. "N" for DNA."""
complement: str
array: NDArray[np.bytes_]
complement_map: Dict[str, str]
complement_map_bytes: Dict[bytes, bytes]
str_comp_table: Dict[int, str]
complement_map: dict[str, str]
complement_map_bytes: dict[bytes, bytes]
str_comp_table: dict[int, str]
bytes_comp_table: bytes
bytes_comp_array: NDArray[np.bytes_]

def __init__(self, alphabet: str, complement: str) -> None:
"""Parse and validate sequence alphabets.
@@ -36,16 +47,15 @@ def __init__(self, alphabet: str, complement: str) -> None:
self.array = cast(
NDArray[np.bytes_], np.frombuffer(self.alphabet.encode("ascii"), "|S1")
)
self.complement_map: Dict[str, str] = dict(
zip(list(self.alphabet), list(self.complement))
)
self.complement_map = dict(zip(list(self.alphabet), list(self.complement)))
self.complement_map_bytes = {
k.encode("ascii"): v.encode("ascii") for k, v in self.complement_map.items()
}
self.str_comp_table = str.maketrans(self.complement_map)
self.bytes_comp_table = bytes.maketrans(
self.alphabet.encode("ascii"), self.complement.encode("ascii")
)
self.bytes_comp_array = np.frombuffer(self.bytes_comp_table, "S1")

def __len__(self):
return len(self.alphabet)
@@ -109,31 +119,37 @@ def decode_ohe(

return _alphabet[idx].reshape(shape)

def complement_bytes(self, byte_arr: NDArray[np.bytes_]) -> NDArray[np.bytes_]:
def complement_bytes(
self, byte_arr: NDArray[np.bytes_], out: NDArray[np.bytes_] | None = None
) -> NDArray[np.bytes_]:
"""Get reverse complement of byte (S1) array.

Parameters
----------
byte_arr : ndarray[bytes]
"""
# * a vectorized implementation using np.unique or np.char.translate is NOT
# * faster even for longer alphabets like IUPAC DNA/RNA. Another optimization to
# * try would be using vectorized bit manipulations.
out = byte_arr.copy()
for nuc, comp in self.complement_map_bytes.items():
out[byte_arr == nuc] = comp
return out
if out is None:
_out = out
else:
_out = out.view(np.uint8)
_out = gufunc_complement_bytes(
byte_arr.view(np.uint8), self.bytes_comp_array.view(np.uint8), _out
)
return _out.view("S1")

def rev_comp_byte(
self, byte_arr: NDArray[np.bytes_], length_axis: int
self,
byte_arr: NDArray[np.bytes_],
length_axis: int,
out: NDArray[np.bytes_] | None = None,
) -> NDArray[np.bytes_]:
"""Get reverse complement of byte (S1) array.

Parameters
----------
byte_arr : ndarray[bytes]
"""
out = self.complement_bytes(byte_arr)
out = self.complement_bytes(byte_arr, out)
return np.flip(out, length_axis)

def rev_comp_string(self, string: str):
@@ -150,27 +166,31 @@ def reverse_complement(
seqs: StrSeqType,
length_axis: Optional[int] = None,
ohe_axis: Optional[int] = None,
out: NDArray[np.bytes_] | None = None,
) -> NDArray[np.bytes_]: ...
@overload
def reverse_complement(
self,
seqs: NDArray[np.uint8],
length_axis: Optional[int] = None,
ohe_axis: Optional[int] = None,
out: NDArray[np.bytes_] | None = None,
) -> NDArray[np.uint8]: ...
@overload
def reverse_complement(
self,
seqs: SeqType,
length_axis: Optional[int] = None,
ohe_axis: Optional[int] = None,
out: NDArray[np.bytes_] | None = None,
) -> NDArray[Union[np.bytes_, np.uint8]]: ...
def reverse_complement(
self,
seqs: SeqType,
length_axis: Optional[int] = None,
ohe_axis: Optional[int] = None,
) -> NDArray[Union[np.bytes_, np.uint8]]:
out: NDArray[np.bytes_] | None = None,
) -> NDArray[np.bytes_ | np.uint8]:
"""Reverse complement a sequence.

Parameters
Expand All @@ -190,14 +210,20 @@ def reverse_complement(

seqs = cast_seqs(seqs)

if seqs.dtype == np.uint8: # OHE
if is_dtype(seqs, np.bytes_):
if length_axis is None:
length_axis = -1
return self.rev_comp_byte(seqs, length_axis, out)
elif is_dtype(seqs, np.uint8): # OHE
assert length_axis is not None
assert ohe_axis is not None
return np.flip(seqs, axis=(length_axis, ohe_axis))
_out = np.flip(seqs, axis=(length_axis, ohe_axis))
if out is not None:
out[:] = _out
_out = out
return _out
else:
if length_axis is None:
length_axis = -1
return self.rev_comp_byte(seqs, length_axis) # type: ignore
assert_never(seqs) # type: ignore


class AminoAlphabet:
@@ -334,3 +360,25 @@ def decode_ohe(
_alphabet = np.concatenate([self.aa_array, [unknown_char.encode("ascii")]])

return _alphabet[idx].reshape(shape)


DNA = NucleotideAlphabet("ACGT", "TGCA")


# * Monkey patch DNA instance with a faster complement function using
# * a static, const lookup table. The base method is slower because it uses a
# * dynamic lookup table.
def complement_bytes(
self: NucleotideAlphabet,
byte_arr: NDArray[np.bytes_],
out: NDArray[np.bytes_] | None = None,
) -> NDArray[np.bytes_]:
if out is None:
_out = out
else:
_out = out.view(np.uint8)
_out = ufunc_comp_dna(byte_arr.view(np.uint8), _out) # type: ignore
return _out.view("S1")


DNA.complement_bytes = MethodType(complement_bytes, DNA)
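
The public reverse_complement now accepts an out buffer, and the DNA singleton picks up the patched fast path above. A minimal sketch of how the pieces compose (assuming DNA is re-exported from seqpro.alphabets as the updated __init__.py shows, and that cast_seqs passes S1 arrays through unchanged):

import numpy as np
from seqpro.alphabets import DNA

seqs = np.frombuffer(b"ACGTTTGCA", "S1").reshape(3, 3)  # three length-3 sequences
buf = np.empty_like(seqs)                               # reused across calls, no per-call allocation

rc = DNA.reverse_complement(seqs, length_axis=-1, out=buf)
# buf holds the complements; rc is a flipped view over that same memory
assert rc.tobytes() == b"CGTAAATGC"
assert buf.tobytes() == b"TGCAAACGT"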
22 changes: 9 additions & 13 deletions python/seqpro/rag/_array.py
@@ -61,7 +61,7 @@ def __init__(
if isinstance(data, RagParts):
content = _parts_to_content(data)
else:
content = _with_ragged(data, highlevel=False)
content = _as_ragged(data, highlevel=False)
super().__init__(content, behavior=deepcopy(ak.behavior))
self._parts = unbox(self)
type_parts: list[str] = []
@@ -232,7 +232,7 @@ def __getitem__(self, where):
if _n_var(arr) == 1:
return type(self)(arr)
else:
return _without_ragged(arr)
return _as_ak(arr)
else:
return arr

@@ -293,7 +293,7 @@ def reshape(self, *shape: int | None | tuple[int | None, ...]) -> Self:

def to_ak(self):
"""Convert to an Awkward array."""
arr = _without_ragged(self)
arr = _as_ak(self)
arr.behavior = None
return arr

@@ -331,12 +331,12 @@ def _n_var(arr: ak.Array) -> int:


@overload
def _with_ragged(
def _as_ragged(
arr: ak.Array | Content, highlevel: Literal[True] = True
) -> ak.Array: ...
@overload
def _with_ragged(arr: ak.Array | Content, highlevel: Literal[False]) -> Content: ...
def _with_ragged(arr: ak.Array | Content, highlevel: bool = True) -> ak.Array | Content:
def _as_ragged(arr: ak.Array | Content, highlevel: Literal[False]) -> Content: ...
def _as_ragged(arr: ak.Array | Content, highlevel: bool = True) -> ak.Array | Content:
def fn(layout: Content, **kwargs):
if isinstance(layout, (ListArray, ListOffsetArray)):
return ak.with_parameter(
Expand All @@ -350,16 +350,12 @@ def fn(layout: Content, **kwargs):


@overload
def _without_ragged(
def _as_ak(
arr: ak.Array | Ragged[DTYPE], highlevel: Literal[True] = True
) -> ak.Array: ...
@overload
def _without_ragged(
arr: ak.Array | Ragged[DTYPE], highlevel: Literal[False]
) -> Content: ...
def _without_ragged(
arr: ak.Array | Ragged[DTYPE], highlevel: bool = True
) -> ak.Array | Content:
def _as_ak(arr: ak.Array | Ragged[DTYPE], highlevel: Literal[False]) -> Content: ...
def _as_ak(arr: ak.Array | Ragged[DTYPE], highlevel: bool = True) -> ak.Array | Content:
def fn(layout, **kwargs):
if isinstance(layout, (ListArray, ListOffsetArray)):
return ak.with_parameter(layout, "__list__", None, highlevel=False)