Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 67 additions & 30 deletions arkouda/pandas/extension/_arkouda_extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from numpy.typing import NDArray
from pandas.api.extensions import ExtensionArray
from pandas.core.arraylike import OpsMixin
from pandas.core.dtypes.base import ExtensionDtype
from typing_extensions import Self

from arkouda.numpy.dtypes import all_scalars
Expand Down Expand Up @@ -264,38 +265,48 @@ def _from_sequence(
copy: bool = False,
) -> "ArkoudaExtensionArray":
"""
Construct an Arkouda-backed ExtensionArray from Arkouda objects or
Python/NumPy scalars.
Construct an Arkouda-backed pandas ExtensionArray from Arkouda objects
or Python/NumPy scalars.

This factory inspects ``scalars`` and returns an instance of the
appropriate concrete subclass:
This method acts as a **factory and dispatcher** for Arkouda-backed
ExtensionArray implementations. It inspects ``scalars``—or, when needed,
the result of converting ``scalars`` to an Arkouda server-side object—
and returns an instance of the appropriate concrete subclass:

* :class:`ArkoudaArray` for :class:`pdarray`
* :class:`ArkoudaArray` for numeric :class:`pdarray`
* :class:`ArkoudaStringArray` for :class:`Strings`
* :class:`ArkoudaCategoricalArray` for :class:`Categorical`
* :class:`ArkoudaCategoricalArray` for pandas-style
:class:`~arkouda.pandas.categorical.Categorical`

If ``scalars`` is **not** already an Arkouda server-side array, it is
interpreted as a sequence of Python/NumPy scalars, converted into a
server-side ``pdarray`` via :func:`arkouda.numpy.pdarraycreation.array`,
and wrapped in :class:`ArkoudaArray`.
This method is the primary construction hook used by pandas when creating
Arkouda-backed arrays via ``pd.array(..., dtype="ak")``.

Parameters
----------
scalars : object
Either an Arkouda array type (``pdarray``, ``Strings``,
or ``Categorical``) or a sequence of Python/NumPy scalars.
Either an Arkouda server-side object (``pdarray``, ``Strings``, or
``Categorical``) or a sequence of Python/NumPy scalars.
dtype : object, optional
Ignored. Present for pandas API compatibility.
Pandas-provided dtype argument. The generic Arkouda dtype
(``"ak"`` or :class:`ArkoudaDtype`) is interpreted as a request for
backend inference and is ignored during server-side construction.
Concrete Arkouda dtypes are not interpreted here.
copy : bool, default False
Ignored. Present for pandas API compatibility.
Present for pandas API compatibility. Currently ignored; Arkouda
array construction semantics determine copying behavior.

Returns
-------
ArkoudaExtensionArray
An instance of :class:`ArkoudaArray`,
:class:`ArkoudaStringArray`, or
:class:`ArkoudaCategoricalArray`, depending on the type of
``scalars``.
:class:`ArkoudaCategoricalArray`, depending on the type of the
resulting Arkouda server-side object.

Raises
------
TypeError
If conversion produces an unsupported Arkouda object type.

Examples
--------
Expand All @@ -315,7 +326,7 @@ def _from_sequence(
>>> ea
ArkoudaStringArray(['red', 'green', 'blue'])

From Python scalars:
From Python scalars (type inferred server-side):

>>> ea = ArkoudaExtensionArray._from_sequence([10, 20, 30])
>>> ea
Expand All @@ -328,26 +339,52 @@ def _from_sequence(
>>> ea
ArkoudaArray([1 2 3])
"""
# Local imports to avoid circular dependencies at module import time.
from arkouda.numpy.pdarrayclass import pdarray
from arkouda.numpy.pdarraycreation import array as ak_array
from arkouda.numpy.strings import Strings
from arkouda.pandas.categorical import Categorical
from arkouda.pandas.categorical import Categorical as ak_Categorical
from arkouda.pandas.extension._arkouda_array import ArkoudaArray
from arkouda.pandas.extension._arkouda_categorical_array import ArkoudaCategoricalArray
from arkouda.pandas.extension._arkouda_string_array import ArkoudaStringArray
from arkouda.pandas.extension._dtypes import ArkoudaCategoricalDtype, ArkoudaDtype

# dtype may be:
# - None
# - a string like "ak" or "ak_int64"
# - an ExtensionDtype instance (e.g. ArkoudaDtype())

if dtype == "ak" or isinstance(dtype, ArkoudaDtype):
dtype = None

if (
isinstance(dtype, ExtensionDtype)
and not isinstance(dtype, ArkoudaCategoricalDtype)
and hasattr(dtype, "_numpy_dtype")
):
dtype = dtype._numpy_dtype

# ---------------------------------------------------------------------
# Convert to Arkouda once, then dispatch on result type
# TODO: streamline Categorical handling
if not isinstance(scalars, ak_Categorical):
if isinstance(dtype, ArkoudaCategoricalDtype):
ak_obj = ak_Categorical(ak_array(scalars, dtype="str_"))
else:
ak_obj = ak_array(scalars, dtype=dtype)
else:
if dtype is None or isinstance(dtype, ArkoudaCategoricalDtype):
ak_obj = scalars
else:
ak_obj = ak_array(scalars.to_strings(), dtype=dtype)

if isinstance(ak_obj, pdarray):
return ArkoudaArray(ak_obj)
if isinstance(ak_obj, Strings):
return ArkoudaStringArray(ak_obj)
if isinstance(ak_obj, ak_Categorical):
return ArkoudaCategoricalArray(ak_obj)

# Fast path: already an Arkouda column. Pick the matching subclass.
if isinstance(scalars, pdarray):
return ArkoudaArray(scalars)
if isinstance(scalars, Strings):
return ArkoudaStringArray(scalars)
if isinstance(scalars, Categorical):
return ArkoudaCategoricalArray(scalars)

# Fallback: treat as a sequence of scalars and build a pdarray.
data = ak_array(scalars)
return ArkoudaArray(data)
raise TypeError(f"Unsupported Arkouda construction result: {type(ak_obj).__name__}")

def _fill_missing(self, mask, fill_value):
raise NotImplementedError("Subclasses must implement _fill_missing")
Expand Down
29 changes: 29 additions & 0 deletions arkouda/pandas/extension/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,35 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.name!r})"


# ---- Generic dtype -------------------------------------------------------------


@register_extension_dtype
class ArkoudaDtype(ExtensionDtype):
"""
Generic Arkouda-backed dtype for pandas construction.

Using dtype="ak" triggers ArkoudaExtensionArray._from_sequence, which
dispatches to ArkoudaArray / ArkoudaStringArray / ArkoudaCategoricalArray.
"""

name = "ak"
type = object # pandas requires something
kind = "O"

@classmethod
def construct_from_string(cls, string):
if string == "ak":
return cls()
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")

def construct_array_type(self):
# Important: return the base class that implements factory dispatch.
from arkouda.pandas.extension._arkouda_extension_array import ArkoudaExtensionArray

return ArkoudaExtensionArray


# ---- Concrete dtypes --------------------------------------------------------


Expand Down
59 changes: 53 additions & 6 deletions tests/pandas/extension/dtypes_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pandas as pd
import pytest

import arkouda as ak

from arkouda.pandas.extension import ArkoudaArray, ArkoudaCategoricalArray, ArkoudaStringArray

# Module under test
Expand Down Expand Up @@ -299,8 +301,6 @@ def test_series_roundtrip_with_arkouda_array(self, dtype_cls, data, expect_dtype
import pandas as pd
import pandas.testing as pdt

import arkouda as ak

from arkouda.pandas.extension._arkouda_array import ArkoudaArray

ak_arr = ak.array(data, dtype=dtype_cls().name)
Expand All @@ -314,8 +314,6 @@ def test_series_roundtrip_with_arkouda_array(self, dtype_cls, data, expect_dtype
pdt.assert_series_equal(s.astype(object), expected, check_names=False)

def test_series_with_strings_dtype(self):
import arkouda as ak

from arkouda.pandas.extension._arkouda_string_array import ArkoudaStringArray

a = ak.array(["a", "b", ""])
Expand All @@ -326,8 +324,6 @@ def test_series_with_strings_dtype(self):
assert s.iloc[2] == ""

def test_series_with_categorical_dtype(self):
import arkouda as ak

from arkouda.pandas.extension._arkouda_categorical_array import (
ArkoudaCategoricalArray,
)
Expand All @@ -338,3 +334,54 @@ def test_series_with_categorical_dtype(self):
assert isinstance(s.dtype, ArkoudaCategoricalDtype)
# categories round-trip to Python scalars on materialization
assert list(s.to_numpy()) == ["x", "y", "x"]


class TestArkoudaGenericDtypesExtension:
@pytest.mark.parametrize(
"data, expected_dtype_cls",
[
(np.array([1, 2, 3], dtype=np.int64), ArkoudaInt64Dtype),
(np.array([1.0, 2.0, 3.0], dtype=np.float64), ArkoudaFloat64Dtype),
],
)
def test_pd_array_dtype_ak_dispatches_numeric(self, data, expected_dtype_cls):
arr = pd.array(data, dtype="ak")
assert isinstance(arr, ArkoudaArray)
assert isinstance(arr.dtype, expected_dtype_cls)

def test_pd_array_dtype_ak_dispatches_strings(self):
arr = pd.array(["red", "green", "blue"], dtype="ak")
assert isinstance(arr, ArkoudaStringArray)
assert isinstance(arr.dtype, ArkoudaStringDtype)

def test_pd_array_dtype_ak_dispatches_categorical_from_ak_pandas_object(self):
# ArkoudaExtensionArray dispatch checks for arkouda.pandas.categorical.Categorical
from arkouda.pandas.categorical import Categorical as PandasCategorical

cat = PandasCategorical(ak.array(["a", "b", "a", "c"]))
arr = pd.array(cat, dtype="ak")

assert isinstance(arr, ArkoudaCategoricalArray)
assert isinstance(arr.dtype, ArkoudaCategoricalDtype)

def test_series_dtype_ak_dispatches_numeric(self):
s = pd.Series(np.array([10, 20, 30], dtype=np.int64), dtype="ak")
assert isinstance(s.array, ArkoudaArray)
assert isinstance(s.dtype, ArkoudaInt64Dtype)

def test_series_dtype_ak_dispatches_strings(self):
s = pd.Series(["x", "y", "z"], dtype="ak")
assert isinstance(s.array, ArkoudaStringArray)
assert isinstance(s.dtype, ArkoudaStringDtype)

def test_series_dtype_ak_dispatches_categorical_from_ak_pandas_object(self):
from arkouda.pandas.categorical import Categorical as PandasCategorical

cat = PandasCategorical(ak.array(["dog", "cat", "dog"]))

# Key: construct the EA first; Series will accept the EA without iterating cat
arr = pd.array(cat, dtype="ak")
s = pd.Series(arr)

assert isinstance(s.array, ArkoudaCategoricalArray)
assert isinstance(s.dtype, ArkoudaCategoricalDtype)