Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions dwave/optimization/_model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import collections.abc
import dataclasses
import fractions
import os
import types
import typing

import numpy
Expand All @@ -26,6 +27,7 @@ import dwave.optimization.typing
from dwave.optimization.utilities import _NoValue, _NoValueType

_AxisLike: typing.TypeAlias = None | int | tuple[int, ...]
_IndexLike: typing.TypeAlias = int | slice | None | types.EllipsisType | ArraySymbol
_InitialLike: typing.TypeAlias = None | _NoValueType | float
_ShapeLike: typing.TypeAlias = typing.Union[int, collections.abc.Sequence[int]]

Expand Down Expand Up @@ -109,12 +111,7 @@ class ArraySymbol(Symbol):
def __bool__(self) -> typing.NoReturn: ...
def __eq__(self, rhs: dwave.optimization.typing.ArraySymbolLike) -> symbols.Equal: ...
def __ge__(self, rhs: dwave.optimization.typing.ArraySymbolLike) -> symbols.LessEqual: ...

def __getitem__(
self,
index: typing.Union[Symbol, int, slice, tuple],
) -> typing.Union[symbols.AdvancedIndexing, symbols.BasicIndexing, symbols.Permutation]: ...

def __getitem__(self, index: _IndexLike | tuple[_IndexLike, ...]) -> ArraySymbol: ...
def __iadd__(self, rhs: dwave.optimization.typing.ArraySymbolLike) -> symbols.NaryAdd: ...
def __imul__(self, rhs: dwave.optimization.typing.ArraySymbolLike) -> symbols.NaryMultiply: ...
def __iter__(self) -> typing.Iterator[ArraySymbol]: ...
Expand Down
103 changes: 24 additions & 79 deletions dwave/optimization/_model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ from dwave.optimization.libcpp.array cimport Array as cppArray, broadcast_shapes
from dwave.optimization.libcpp.graph cimport DecisionNode as cppDecisionNode
from dwave.optimization.states cimport States
from dwave.optimization.states import StateView
from dwave.optimization.utilities import _file_object_arg, _lock, _NoValue, _TypeError_to_NotImplemented
from dwave.optimization.utilities import (
_file_object_arg,
_lock,
_NoValue,
_split_indices,
_TypeError_to_NotImplemented,
)

__all__ = []

Expand Down Expand Up @@ -1339,44 +1345,6 @@ cdef class Symbol:
index = self.node_ptr.topological_index()
return index if index >= 0 else None

def _split_indices(indices):
"""Given a set of indices, made up of slices, integers, and array symbols,
create two consecutive indexing operations that can be passed to
BasicIndexing and AdvancedIndexing respectively.
"""
# this is pure-Python and could be moved out of this .pyx file at some point

basic_indices = []
advanced_indices = []

for index in indices:
if isinstance(index, numbers.Integral):
# Only basic handles numeric indices and it removes the axis so
# only one gets the index.
basic_indices.append(index)
elif isinstance(index, slice):
if index.start is None and index.stop is None and index.step is None:
# empty slice, both handle it
basic_indices.append(index)
advanced_indices.append(index)
else:
# Advanced can only handle empty slices, so we do basic first
basic_indices.append(index)
advanced_indices.append(slice(None))
elif isinstance(index, (ArraySymbol, np.ndarray)):
# Only advanced handles arrays, it preserves the axis so basic gets
# an empty slice.
# We allow np.ndarray here for testing purposes. They are not (yet)
# natively handled by AdvancedIndexingNode.
basic_indices.append(slice(None))
advanced_indices.append(index)

else:
# this should be checked by the calling function, but just in case
raise RuntimeError("unexpected index type")

return tuple(basic_indices), tuple(advanced_indices)


# dev note: most documentation is in the `ArraySymbol.info()` method.
@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -1512,54 +1480,31 @@ cdef class ArraySymbol(Symbol):
return dwave.optimization.mathematical.less_equal(rhs, self)

def __getitem__(self, index):
import dwave.optimization.symbols # avoid circular import
if isinstance(index, tuple):
index = list(index)

# for all indexing styles, empty slices are padded to fill out the
# number of dimension
while len(index) < self.ndim():
index.append(slice(None))

# replace any array-like indices with constants
for i in range(len(index)):
if isinstance(index[i], numbers.Integral):
continue
try:
index[i] = _as_array_symbol(self.model, index[i])
except (TypeError, ValueError):
pass
if not isinstance(index, tuple):
return self[(index,)]

if all(isinstance(idx, (slice, numbers.Integral)) for idx in index):
# Basic indexing
# https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing
return dwave.optimization.symbols.BasicIndexing(self, *index)
import dwave.optimization.symbols # avoid circular import

elif all(isinstance(idx, ArraySymbol) or
(isinstance(idx, slice) and idx.start is None and idx.stop is None and idx.step is None)
for idx in index):
# Advanced indexing
# https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing
newaxes, basic, advanced = _split_indices(self.shape(), index)

return dwave.optimization.symbols.AdvancedIndexing(self, *index)
out = self

elif all(isinstance(idx, (ArraySymbol, slice, numbers.Integral)) for idx in index):
# Combined indexing
# https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
if newaxes:
raise NotImplementedError

# We handle this by doing basic and then advanced indexing. In principle the other
# order may be more efficient in some cases, but for now let's do the simple thing
if not all(isinstance(index, slice) and index == slice(None) for index in basic):
out = dwave.optimization.symbols.BasicIndexing(out, *basic)

basic_indices, advanced_indices = _split_indices(index)
basic = dwave.optimization.symbols.BasicIndexing(self, *basic_indices)
return dwave.optimization.symbols.AdvancedIndexing(basic, *advanced_indices)
if not all(isinstance(index, slice) and index == slice(None) for index in advanced):
# Do the type promotion to ArraySymbol
try:
adv = [index if isinstance(index, slice) else _as_array_symbol(self.model, index) for index in advanced]
except ValueError as err:
raise IndexError(err)

else:
# this error message is chosen to be similar to NumPy's
raise IndexError("only integers, slices (`:`), arrays, and array symbols are valid indices")
out = dwave.optimization.symbols.AdvancedIndexing(out, *adv)

else:
return self[(index,)]
return out

@_TypeError_to_NotImplemented
def __iadd__(self, rhs):
Expand Down
4 changes: 2 additions & 2 deletions dwave/optimization/src/nodes/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,12 @@ struct AdvancedIndexingNode::IndexParser_ {
assert(array_ptr->ndim() >= 0); // Should always be true
if (static_cast<std::size_t>(array_ptr->ndim()) < indices_.size()) {
// NumPy handles this case, we could as well
throw std::invalid_argument(std::string("too few indices for array: array is ") +
throw std::invalid_argument(std::string("too many indices for array: array is ") +
std::to_string(array_ptr->ndim()) + "-dimensional, but " +
std::to_string(indices_.size()) + " were indexed");
}
if (static_cast<std::size_t>(array_ptr->ndim()) > indices_.size()) {
throw std::invalid_argument(std::string("too many indices for array: array is ") +
throw std::invalid_argument(std::string("too few indices for array: array is ") +
std::to_string(array_ptr->ndim()) + "-dimensional, but " +
std::to_string(indices_.size()) + " were indexed");
}
Expand Down
8 changes: 8 additions & 0 deletions dwave/optimization/symbols/indexing.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ cdef class AdvancedIndexing(ArraySymbol):

cppindices.emplace_back(array_index.array_ptr)

# If we had too few indexers, pad the remaining with empty slices
while cppindices.size() < array.ndim():
cppindices.emplace_back(Slice())

self.ptr = model._graph.emplace_node[AdvancedIndexingNode](array.array_ptr, cppindices)

self.initialize_arraynode(model, self.ptr)
Expand Down Expand Up @@ -169,6 +173,10 @@ cdef class BasicIndexing(ArraySymbol):
else:
cppindices.emplace_back(<Py_ssize_t>(index))

# If we had too few indexers, pad the remaining with empty slices
while cppindices.size() < array.ndim():
cppindices.emplace_back(Slice())

self.ptr = model._graph.emplace_node[BasicIndexingNode](array.array_ptr, cppindices)

self.initialize_arraynode(model, self.ptr)
Expand Down
80 changes: 79 additions & 1 deletion dwave/optimization/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import functools
import numbers
import os
import types

__all__ = []

Expand Down Expand Up @@ -76,6 +78,82 @@ def _method(obj, *args, **kwargs):
return _method


def _split_indices(
shape: tuple[int, ...],
index: tuple[int | slice | None | types.EllipsisType | object, ...],
):
"""Given a combined indexing operation, split into several steps.

Args:
shape: The shape of the indexed array.
index: A tuple of indexers.

Returns:
In order that they should be applied:
* A list of new axes that should be added, as with ``np.expand_dims()``.
* Indexers that can be given to ``BasicIndexing``
* Indexers that can be given to ``AdvancedIndexing``.

"""
# developer note: there is a known issue where splitting the indexers in this
# way cannot account for all cases.
# See https://github.com/dwavesystems/dwave-optimization/issues/465 for more
# information.

indices: list = list(index)

# Handle the ellipses if it's present
if (count := sum(idx is ... for idx in indices)) > 1:
raise IndexError("an index can only have a single ellipsis ('...')")
elif count == 1:
# We have an ellipses, so we need to replace it with empty slice(s) until
# we hit the correct length

# First, find where the ellipses is
for loc, index in enumerate(indices):
if index is ...:
break
else:
raise RuntimeError # shouldn't be able to get here

# Now that we know where the ellipses is, we remove it and then replace
# it with empty slices until we hit our desired ndim. Though we need
# to make sure not to count the newaxes towards that count
indices.pop(loc)
for _ in range(sum(idx is not None for idx in indices), len(shape)):
indices.insert(loc, slice(None))

# Now divide everything that's remaining between basic and advanced indexing
newaxes: list[int] = []
basic: list[slice | int] = []
advanced: list[slice | object] = []
for i, index in enumerate(indices):
if index is None:
# We'll insert the new axis before calling basid/advanced indexing
newaxes.append(i)
basic.append(slice(None))
advanced.append(slice(None))
elif isinstance(index, numbers.Integral):
# Only basic handles numeric indices and it removes the axis so
# only basic gets the index
basic.append(index)
elif isinstance(index, slice) and index == slice(None):
# Empty slices are handled by both basic and advanced indexing
basic.append(slice(None))
advanced.append(slice(None))
elif isinstance(index, slice):
# Non-empty slice are only handled by basic indexing
basic.append(index)
advanced.append(slice(None))
else:
# For anything else, we defer to advanced indexing for the type
# checking
basic.append(slice(None))
advanced.append(index)

return tuple(newaxes), tuple(basic), tuple(advanced)


def _TypeError_to_NotImplemented(f):
"""Convert any TypeErrors raised by the given function into NotImplemented"""
@functools.wraps(f)
Expand Down
9 changes: 9 additions & 0 deletions releasenotes/notes/feature-newaxis-b42829ddab7379a7.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
features:
- |
Support ``numpy.newaxis`` and ``...`` while indexing array symbols.
issues:
- |
In some cases, combined indexing operations on high dimensional arrays will
not match the behavior of NumPy.
See `#465 <https://github.com/dwavesystems/dwave-optimization/issues/465>`_.
51 changes: 0 additions & 51 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,57 +186,6 @@ class UnknownType():
numerator, denominator = two_over_x.iter_predecessors()
self.assertEqual(denominator.id(), x_.id())

class IndexTester:
def __init__(self, array):
self.array = array

def __getitem__(self, index):
if not isinstance(index, tuple):
return self[(index,)]
i0, i1 = dwave.optimization._model._split_indices(index)
np.testing.assert_array_equal(self.array[index], self.array[i0][i1])

def test_split_indices(self):
a1d = np.arange(5)
a2d = np.arange(5 * 6).reshape(5, 6)
a4d = np.arange(5 * 6 * 7 * 8).reshape(5, 6, 7, 8)

for arr in [a1d, a2d, a4d]:
test = self.IndexTester(arr)

test[:]
test[0]
test[np.asarray([0, 1, 2])]
test[1:]
test[1:4:2]

if arr.ndim < 2:
continue

test[:, :]
test[0, 1]
test[:2, :]
test[:, 1::2]
test[0, :]
test[:, 3]
test[::2, 2]
test[np.asarray([0, 2, 1]), np.asarray([0, 0, 0])]
test[np.asarray([0, 2, 1]), :]
test[np.asarray([0, 2, 1]), 3]
test[:, np.asarray([0, 2, 1])]
test[3, np.asarray([0, 2, 1])]

if arr.ndim < 3:
continue

test[:, :, :, :]
test[0, 1, 2, 4]
test[::2, 0, ::2, :]

# two different types of combined indexing
test[np.asarray([0, 2, 1]), :, np.asarray([0, 0, 0]), np.asarray([0, 0, 0])]
test[np.asarray([0, 2, 1]), np.asarray([0, 0, 0]), np.asarray([0, 0, 0]), :]


class TestModel(unittest.TestCase):
def test(self):
Expand Down
Loading