From 542ef32cb3a1d6835f813146fdf67ac82d05776b Mon Sep 17 00:00:00 2001 From: Alexander Condello Date: Wed, 21 Jan 2026 13:56:47 -0800 Subject: [PATCH] Support numpy.newaxis while indexing array symbols --- dwave/optimization/_model.pyi | 9 +- dwave/optimization/_model.pyx | 103 ++++-------------- dwave/optimization/src/nodes/indexing.cpp | 4 +- dwave/optimization/symbols/indexing.pyx | 8 ++ dwave/optimization/utilities.py | 80 +++++++++++++- .../feature-newaxis-b42829ddab7379a7.yaml | 9 ++ tests/test_model.py | 51 --------- tests/test_symbols.py | 21 +++- tests/test_utilities.py | 81 +++++++++++++- 9 files changed, 220 insertions(+), 146 deletions(-) create mode 100644 releasenotes/notes/feature-newaxis-b42829ddab7379a7.yaml diff --git a/dwave/optimization/_model.pyi b/dwave/optimization/_model.pyi index 63130e6d..53adc213 100644 --- a/dwave/optimization/_model.pyi +++ b/dwave/optimization/_model.pyi @@ -16,6 +16,7 @@ import collections.abc import dataclasses import fractions import os +import types import typing import numpy @@ -26,6 +27,7 @@ import dwave.optimization.typing from dwave.optimization.utilities import _NoValue, _NoValueType _AxisLike: typing.TypeAlias = None | int | tuple[int, ...] +_IndexLike: typing.TypeAlias = int | slice | None | types.EllipsisType | ArraySymbol _InitialLike: typing.TypeAlias = None | _NoValueType | float _ShapeLike: typing.TypeAlias = typing.Union[int, collections.abc.Sequence[int]] @@ -109,12 +111,7 @@ class ArraySymbol(Symbol): def __bool__(self) -> typing.NoReturn: ... def __eq__(self, rhs: dwave.optimization.typing.ArraySymbolLike) -> symbols.Equal: ... def __ge__(self, rhs: dwave.optimization.typing.ArraySymbolLike) -> symbols.LessEqual: ... - - def __getitem__( - self, - index: typing.Union[Symbol, int, slice, tuple], - ) -> typing.Union[symbols.AdvancedIndexing, symbols.BasicIndexing, symbols.Permutation]: ... - + def __getitem__(self, index: _IndexLike | tuple[_IndexLike, ...]) -> ArraySymbol: ... 
def __iadd__(self, rhs: dwave.optimization.typing.ArraySymbolLike) -> symbols.NaryAdd: ... def __imul__(self, rhs: dwave.optimization.typing.ArraySymbolLike) -> symbols.NaryMultiply: ... def __iter__(self) -> typing.Iterator[ArraySymbol]: ... diff --git a/dwave/optimization/_model.pyx b/dwave/optimization/_model.pyx index 09fc1734..07732214 100644 --- a/dwave/optimization/_model.pyx +++ b/dwave/optimization/_model.pyx @@ -44,7 +44,13 @@ from dwave.optimization.libcpp.array cimport Array as cppArray, broadcast_shapes from dwave.optimization.libcpp.graph cimport DecisionNode as cppDecisionNode from dwave.optimization.states cimport States from dwave.optimization.states import StateView -from dwave.optimization.utilities import _file_object_arg, _lock, _NoValue, _TypeError_to_NotImplemented +from dwave.optimization.utilities import ( + _file_object_arg, + _lock, + _NoValue, + _split_indices, + _TypeError_to_NotImplemented, +) __all__ = [] @@ -1339,44 +1345,6 @@ cdef class Symbol: index = self.node_ptr.topological_index() return index if index >= 0 else None -def _split_indices(indices): - """Given a set of indices, made up of slices, integers, and array symbols, - create two consecutive indexing operations that can be passed to - BasicIndexing and AdvancedIndexing respectively. - """ - # this is pure-Python and could be moved out of this .pyx file at some point - - basic_indices = [] - advanced_indices = [] - - for index in indices: - if isinstance(index, numbers.Integral): - # Only basic handles numeric indices and it removes the axis so - # only one gets the index. 
- basic_indices.append(index) - elif isinstance(index, slice): - if index.start is None and index.stop is None and index.step is None: - # empty slice, both handle it - basic_indices.append(index) - advanced_indices.append(index) - else: - # Advanced can only handle empty slices, so we do basic first - basic_indices.append(index) - advanced_indices.append(slice(None)) - elif isinstance(index, (ArraySymbol, np.ndarray)): - # Only advanced handles arrays, it preserves the axis so basic gets - # an empty slice. - # We allow np.ndarray here for testing purposes. They are not (yet) - # natively handled by AdvancedIndexingNode. - basic_indices.append(slice(None)) - advanced_indices.append(index) - - else: - # this should be checked by the calling function, but just in case - raise RuntimeError("unexpected index type") - - return tuple(basic_indices), tuple(advanced_indices) - # dev note: most documentation is in the `ArraySymbol.info()` method. @dataclasses.dataclass(frozen=True) @@ -1512,54 +1480,31 @@ cdef class ArraySymbol(Symbol): return dwave.optimization.mathematical.less_equal(rhs, self) def __getitem__(self, index): - import dwave.optimization.symbols # avoid circular import - if isinstance(index, tuple): - index = list(index) - - # for all indexing styles, empty slices are padded to fill out the - # number of dimension - while len(index) < self.ndim(): - index.append(slice(None)) - - # replace any array-like indices with constants - for i in range(len(index)): - if isinstance(index[i], numbers.Integral): - continue - try: - index[i] = _as_array_symbol(self.model, index[i]) - except (TypeError, ValueError): - pass + if not isinstance(index, tuple): + return self[(index,)] - if all(isinstance(idx, (slice, numbers.Integral)) for idx in index): - # Basic indexing - # https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing - return dwave.optimization.symbols.BasicIndexing(self, *index) + import dwave.optimization.symbols # avoid circular import - elif 
all(isinstance(idx, ArraySymbol) or - (isinstance(idx, slice) and idx.start is None and idx.stop is None and idx.step is None) - for idx in index): - # Advanced indexing - # https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing + newaxes, basic, advanced = _split_indices(self.shape(), index) - return dwave.optimization.symbols.AdvancedIndexing(self, *index) + out = self - elif all(isinstance(idx, (ArraySymbol, slice, numbers.Integral)) for idx in index): - # Combined indexing - # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + if newaxes: + raise NotImplementedError - # We handle this by doing basic and then advanced indexing. In principle the other - # order may be more efficient in some cases, but for now let's do the simple thing + if not all(isinstance(index, slice) and index == slice(None) for index in basic): + out = dwave.optimization.symbols.BasicIndexing(out, *basic) - basic_indices, advanced_indices = _split_indices(index) - basic = dwave.optimization.symbols.BasicIndexing(self, *basic_indices) - return dwave.optimization.symbols.AdvancedIndexing(basic, *advanced_indices) + if not all(isinstance(index, slice) and index == slice(None) for index in advanced): + # Do the type promotion to ArraySymbol + try: + adv = [index if isinstance(index, slice) else _as_array_symbol(self.model, index) for index in advanced] + except ValueError as err: + raise IndexError(err) - else: - # this error message is chosen to be similar to NumPy's - raise IndexError("only integers, slices (`:`), arrays, and array symbols are valid indices") + out = dwave.optimization.symbols.AdvancedIndexing(out, *adv) - else: - return self[(index,)] + return out @_TypeError_to_NotImplemented def __iadd__(self, rhs): diff --git a/dwave/optimization/src/nodes/indexing.cpp b/dwave/optimization/src/nodes/indexing.cpp index 44f91cf9..765b4f36 100644 --- a/dwave/optimization/src/nodes/indexing.cpp +++ 
b/dwave/optimization/src/nodes/indexing.cpp @@ -234,12 +234,12 @@ struct AdvancedIndexingNode::IndexParser_ { assert(array_ptr->ndim() >= 0); // Should always be true if (static_cast(array_ptr->ndim()) < indices_.size()) { // NumPy handles this case, we could as well - throw std::invalid_argument(std::string("too few indices for array: array is ") + + throw std::invalid_argument(std::string("too many indices for array: array is ") + std::to_string(array_ptr->ndim()) + "-dimensional, but " + std::to_string(indices_.size()) + " were indexed"); } if (static_cast(array_ptr->ndim()) > indices_.size()) { - throw std::invalid_argument(std::string("too many indices for array: array is ") + + throw std::invalid_argument(std::string("too few indices for array: array is ") + std::to_string(array_ptr->ndim()) + "-dimensional, but " + std::to_string(indices_.size()) + " were indexed"); } diff --git a/dwave/optimization/symbols/indexing.pyx b/dwave/optimization/symbols/indexing.pyx index 42404272..c042f7e1 100644 --- a/dwave/optimization/symbols/indexing.pyx +++ b/dwave/optimization/symbols/indexing.pyx @@ -60,6 +60,10 @@ cdef class AdvancedIndexing(ArraySymbol): cppindices.emplace_back(array_index.array_ptr) + # If we had too few indexers, pad the remaining with empty slices + while cppindices.size() < array.ndim(): + cppindices.emplace_back(Slice()) + self.ptr = model._graph.emplace_node[AdvancedIndexingNode](array.array_ptr, cppindices) self.initialize_arraynode(model, self.ptr) @@ -169,6 +173,10 @@ cdef class BasicIndexing(ArraySymbol): else: cppindices.emplace_back((index)) + # If we had too few indexers, pad the remaining with empty slices + while cppindices.size() < array.ndim(): + cppindices.emplace_back(Slice()) + self.ptr = model._graph.emplace_node[BasicIndexingNode](array.array_ptr, cppindices) self.initialize_arraynode(model, self.ptr) diff --git a/dwave/optimization/utilities.py b/dwave/optimization/utilities.py index a5350ac2..18899f9f 100644 --- 
a/dwave/optimization/utilities.py +++ b/dwave/optimization/utilities.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import functools +import numbers +import os +import types __all__ = [] @@ -76,6 +78,82 @@ def _method(obj, *args, **kwargs): return _method +def _split_indices( + shape: tuple[int, ...], + index: tuple[int | slice | None | types.EllipsisType | object, ...], +): + """Given a combined indexing operation, split into several steps. + + Args: + shape: The shape of the indexed array. + index: A tuple of indexers. + + Returns: + In order that they should be applied: + * A tuple of new axes that should be added, as with ``np.expand_dims()``. + * Indexers that can be given to ``BasicIndexing``. + * Indexers that can be given to ``AdvancedIndexing``. + + """ + # developer note: there is a known issue where splitting the indexers in this + # way cannot account for all cases. + # See https://github.com/dwavesystems/dwave-optimization/issues/465 for more + # information. + + indices: list = list(index) + + # Handle the ellipsis if it's present + if (count := sum(idx is ... for idx in indices)) > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + elif count == 1: + # We have an ellipsis, so we need to replace it with empty slice(s) until + # we hit the correct length + + # First, find where the ellipsis is + for loc, index in enumerate(indices): + if index is ...: + break + else: + raise RuntimeError # shouldn't be able to get here + + # Now that we know where the ellipsis is, we remove it and then replace + # it with empty slices until we hit our desired ndim.
Though we need + # to make sure not to count the newaxes towards that count + indices.pop(loc) + for _ in range(sum(idx is not None for idx in indices), len(shape)): + indices.insert(loc, slice(None)) + + # Now divide everything that's remaining between basic and advanced indexing + newaxes: list[int] = [] + basic: list[slice | int] = [] + advanced: list[slice | object] = [] + for i, index in enumerate(indices): + if index is None: + # We'll insert the new axis before calling basic/advanced indexing + newaxes.append(i) + basic.append(slice(None)) + advanced.append(slice(None)) + elif isinstance(index, numbers.Integral): + # Only basic handles numeric indices and it removes the axis so + # only basic gets the index + basic.append(index) + elif isinstance(index, slice) and index == slice(None): + # Empty slices are handled by both basic and advanced indexing + basic.append(slice(None)) + advanced.append(slice(None)) + elif isinstance(index, slice): + # Non-empty slices are only handled by basic indexing + basic.append(index) + advanced.append(slice(None)) + else: + # For anything else, we defer to advanced indexing for the type + # checking + basic.append(slice(None)) + advanced.append(index) + + return tuple(newaxes), tuple(basic), tuple(advanced) + + def _TypeError_to_NotImplemented(f): """Convert any TypeErrors raised by the given function into NotImplemented""" @functools.wraps(f) diff --git a/releasenotes/notes/feature-newaxis-b42829ddab7379a7.yaml b/releasenotes/notes/feature-newaxis-b42829ddab7379a7.yaml new file mode 100644 index 00000000..67f4bdda --- /dev/null +++ b/releasenotes/notes/feature-newaxis-b42829ddab7379a7.yaml @@ -0,0 +1,9 @@ +--- +features: + - | + Support ``numpy.newaxis`` and ``...`` while indexing array symbols. +issues: + - | + In some cases, combined indexing operations on high-dimensional arrays will + not match the behavior of NumPy. + See `#465 <https://github.com/dwavesystems/dwave-optimization/issues/465>`_.
diff --git a/tests/test_model.py b/tests/test_model.py index d7a05abd..41de0f34 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -186,57 +186,6 @@ class UnknownType(): numerator, denominator = two_over_x.iter_predecessors() self.assertEqual(denominator.id(), x_.id()) - class IndexTester: - def __init__(self, array): - self.array = array - - def __getitem__(self, index): - if not isinstance(index, tuple): - return self[(index,)] - i0, i1 = dwave.optimization._model._split_indices(index) - np.testing.assert_array_equal(self.array[index], self.array[i0][i1]) - - def test_split_indices(self): - a1d = np.arange(5) - a2d = np.arange(5 * 6).reshape(5, 6) - a4d = np.arange(5 * 6 * 7 * 8).reshape(5, 6, 7, 8) - - for arr in [a1d, a2d, a4d]: - test = self.IndexTester(arr) - - test[:] - test[0] - test[np.asarray([0, 1, 2])] - test[1:] - test[1:4:2] - - if arr.ndim < 2: - continue - - test[:, :] - test[0, 1] - test[:2, :] - test[:, 1::2] - test[0, :] - test[:, 3] - test[::2, 2] - test[np.asarray([0, 2, 1]), np.asarray([0, 0, 0])] - test[np.asarray([0, 2, 1]), :] - test[np.asarray([0, 2, 1]), 3] - test[:, np.asarray([0, 2, 1])] - test[3, np.asarray([0, 2, 1])] - - if arr.ndim < 3: - continue - - test[:, :, :, :] - test[0, 1, 2, 4] - test[::2, 0, ::2, :] - - # two different types of combined indexing - test[np.asarray([0, 2, 1]), :, np.asarray([0, 0, 0]), np.asarray([0, 0, 0])] - test[np.asarray([0, 2, 1]), np.asarray([0, 0, 0]), np.asarray([0, 0, 0]), :] - class TestModel(unittest.TestCase): def test(self): diff --git a/tests/test_symbols.py b/tests/test_symbols.py index cbd7e935..f540f061 100644 --- a/tests/test_symbols.py +++ b/tests/test_symbols.py @@ -342,7 +342,7 @@ def test_constant_promotion(self): "index may not contain non-integer values for axis 0"): x[1, [0, 1.1], model.constant([0, 3])] - with self.assertRaisesRegex(IndexError, "only integers, slices"): + with self.assertRaisesRegex(IndexError, "array must not contain infs or NaNs"): x[1, [0, 
float("inf")], model.constant([0, 3])] @@ -624,7 +624,10 @@ def test_infer_indices_1d(self): model = Model() x = model.binary(10) - self.assertEqual(x[:]._infer_indices(), (slice(0, 10, 1),)) + self.assertEqual( + dwave.optimization.symbols.BasicIndexing(x)._infer_indices(), + (slice(0, 10, 1),), + ) self.assertEqual(x[1::]._infer_indices(), (slice(1, 10, 1),)) self.assertEqual(x[:3:]._infer_indices(), (slice(0, 3, 1),)) self.assertEqual(x[::2]._infer_indices(), (slice(0, 10, 2),)) @@ -648,7 +651,9 @@ def test_infer_indices_2d(self): self.assertEqual(x[3, 4::2]._infer_indices(), (3, slice(4, 6, 2))) self.assertEqual(x[3, 4:4:2]._infer_indices(), (3, slice(4, 4, 2))) - self.assertEqual(x[:, :]._infer_indices(), (slice(0, 5, 1), slice(0, 6, 1))) + self.assertEqual(dwave.optimization.symbols.BasicIndexing(x)._infer_indices(), + (slice(0, 5, 1), slice(0, 6, 1)), + ) self.assertEqual(x[::2, :]._infer_indices(), (slice(0, 5, 2), slice(0, 6, 1))) self.assertEqual(x[:, ::2]._infer_indices(), (slice(0, 5, 1), slice(0, 6, 2))) self.assertEqual(x[2:, ::2]._infer_indices(), (slice(2, 5, 1), slice(0, 6, 2))) @@ -661,7 +666,10 @@ def test_infer_indices_3d(self): self.assertEqual(x[0, 4, 0]._infer_indices(), (0, 4, 0)) self.assertEqual(x[0, 0, 0]._infer_indices(), (0, 0, 0)) - self.assertEqual(x[:]._infer_indices(), (slice(0, 5, 1), slice(0, 6, 1), slice(0, 7, 1))) + self.assertEqual( + dwave.optimization.symbols.BasicIndexing(x)._infer_indices(), + (slice(0, 5, 1), slice(0, 6, 1), slice(0, 7, 1)), + ) self.assertEqual(x[:, 3, :]._infer_indices(), (slice(0, 5, 1), 3, slice(0, 7, 1))) self.assertEqual(x[:, :, 3]._infer_indices(), (slice(0, 5, 1), slice(0, 6, 1), 3)) @@ -671,7 +679,10 @@ def test_infer_indices_1d_dynamic(self): model = Model() x = model.set(10) - self.assertEqual(x[:]._infer_indices(), (slice(0, MAX, 1),)) + self.assertEqual( + dwave.optimization.symbols.BasicIndexing(x)._infer_indices(), + (slice(0, MAX, 1),), + ) self.assertEqual(x[::2]._infer_indices(), 
(slice(0, MAX, 2),)) self.assertEqual(x[5:2:2]._infer_indices(), (slice(5, 2, 2),)) self.assertEqual(x[:2:]._infer_indices(), (slice(0, 2, 1),)) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 2237f879..9e743b6b 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -14,9 +14,86 @@ import unittest -from dwave.optimization.utilities import _NoValue, _NoValueType +import numpy as np + +import dwave.optimization class Test_NoValue(unittest.TestCase): def test(self): - self.assertIs(_NoValue, _NoValueType()) + self.assertIs( + dwave.optimization.utilities._NoValue, + dwave.optimization.utilities._NoValueType(), + ) + + +class Test_split_indices(unittest.TestCase): + class IndexTester: + def __init__(self, array): + self.array = array + + def __getitem__(self, index): + if not isinstance(index, tuple): + return self[(index,)] + dims, basic, advanced = dwave.optimization.utilities._split_indices(self.array.shape, index) + np.testing.assert_array_equal(self.array[index], np.expand_dims(self.array, dims)[basic][advanced]) + + def test_equivalence(self): + a1d = np.arange(5) + a2d = np.arange(5 * 6).reshape(5, 6) + a4d = np.arange(5 * 6 * 7 * 8).reshape(5, 6, 7, 8) + + for arr in [a1d, a2d, a4d]: + test = self.IndexTester(arr) + + test[:] + test[0] + test[np.int8(0)] + test[np.asarray([0, 1, 2])] + test[1:] + test[1:4:2] + test[:, np.newaxis] + test[np.newaxis, :] + test[np.newaxis, :, np.newaxis] + test[...] + test[..., 0] + test[:, ...] + test[np.newaxis, ...] 
+ test[..., np.newaxis] + test[:, ..., np.newaxis] + test[..., :, np.newaxis] + + if arr.ndim < 2: + continue + + test[:, :] + test[0, 1] + test[:2, :] + test[:, 1::2] + test[0, :] + test[:, 3] + test[::2, 2] + test[np.asarray([0, 2, 1]), np.asarray([0, 0, 0])] + test[np.asarray([0, 2, 1]), :] + test[np.asarray([0, 2, 1]), 3] + test[:, np.asarray([0, 2, 1])] + test[3, np.asarray([0, 2, 1])] + test[..., 3, np.asarray([0, 2, 1])] + test[np.newaxis, ..., 3, np.asarray([0, 2, 1])] + + if arr.ndim < 3: + continue + + test[:, :, :, :] + test[0, 1, 2, 4] + test[::2, 0, ::2, :] + test[:, ..., :] + test[:, ..., ::2] + test[np.newaxis, :] + test[..., 0, np.newaxis, 2] + + # known issue: https://github.com/dwavesystems/dwave-optimization/issues/465 + # test[:, 0, :, [2]] + + test[np.asarray([0, 2, 1]), :, np.asarray([0, 0, 0]), np.asarray([0, 0, 0])] + test[np.asarray([0, 2, 1]), np.asarray([0, 0, 0]), np.asarray([0, 0, 0]), :]