From ffe385ba476e1acc84f38aad344bdee5eb3bbffe Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Fri, 25 Jul 2025 15:25:13 +0300 Subject: [PATCH 1/2] feat: deprecate automatic flattening --- modepy/shapes.py | 98 ++++++++++++++++++++------------------- modepy/spaces.py | 66 ++++++++++++++------------ modepy/test/test_tools.py | 3 -- 3 files changed, 87 insertions(+), 80 deletions(-) diff --git a/modepy/shapes.py b/modepy/shapes.py index 531d5d7..45f1948 100644 --- a/modepy/shapes.py +++ b/modepy/shapes.py @@ -221,7 +221,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from functools import partial, singledispatch -from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload +from typing import TYPE_CHECKING, TypeVar, cast import numpy as np from typing_extensions import override @@ -343,25 +343,23 @@ class TensorProductShape(Shape): bases: tuple[Shape, ...] """A :class:`tuple` of base shapes that form the tensor product.""" - @overload - # pyright-ignore: they overlap, can't be helped. - def __new__(cls, bases: tuple[ShapeT]) -> ShapeT: ... # pyright: ignore[reportOverlappingOverload] - - @overload - def __new__(cls, bases: tuple[Shape, ...]) -> TensorProductShape: ... + def __init__(self, + bases: tuple[Shape, ...], + flatten: bool = True) -> None: + # flatten input shapes + if flatten: + from warnings import warn - def __new__(cls, bases: tuple[Shape, ...]) -> Shape: - if len(bases) == 1: - return bases[0] - else: - return Shape.__new__(cls) + if any(isinstance(s, TensorProductShape) for s in bases): + warn(f"Automatic flattening in the '{type(self).__name__}' constructor " + "is deprecated and will be set to False in 2026. Use " + f"'{type(self).__name__}.flatten()' instead to manually flatten.", + DeprecationWarning, stacklevel=2) - def __init__(self, bases: tuple[Shape, ...]) -> None: - # flatten input shapes - bases = sum(( - s.bases if isinstance(s, TensorProductShape) else (s,) - for s in bases - ), ()) + bases = sum(( + s.bases if isinstance(s, TensorProductShape) else (s,) + for s in bases + ), ()) nsegments = len([s for s in bases if s.dim == 1]) if nsegments < len(bases) - 1: @@ -384,7 +382,27 @@ def nvertices(self) -> int: def nfaces(self) -> int: # FIXME: this obviously only works for `shape x segment x segment x ...` *segments, shape = sorted(self.bases, key=lambda s: s.dim) - return shape.nfaces + len(segments) * segments[0].nfaces + nfaces = segments[0].nfaces if segments else 0 + return shape.nfaces + len(segments) * nfaces + + def flatten(self) -> Shape: + """Flattens a tensor product shape into its component pieces. + + This function recursively removes tensor product shapes from + :attr:`TensorProductShape.bases`. If only a single shape remains in the + tensor product, then it is returned directly. + """ + bases: list[Shape] = [] + for s in self.bases: + if isinstance(s, TensorProductShape): + s = s.flatten() + + bases.extend(s.bases if isinstance(s, TensorProductShape) else [s]) + + if len(bases) == 1: + return bases[0] + else: + return TensorProductShape(tuple(bases)) @unit_vertices_for_shape.register(TensorProductShape) @@ -470,24 +488,12 @@ def faces_for_simplex(shape: Simplex) -> tuple[Face, ...]: # {{{ hypercube -@dataclass(frozen=True) +@dataclass(frozen=True, init=False) class Hypercube(TensorProductShape): """An n-dimensional hypercube (line, square, hexahedron, etc.).""" - @overload - def __new__(cls, dim: Literal[1]) -> Simplex: ... - - @overload - def __new__(cls, dim: int) -> Hypercube: ... - - def __new__(cls, dim: int) -> Shape: - if dim == 1: - return Simplex(1) - else: - return Shape.__new__(cls) - def __init__(self, dim: int) -> None: - super().__init__((Simplex(1),) * dim) + super().__init__((Simplex(1),) * dim, flatten=False) def __getnewargs__(self): # NOTE: ensures Hypercube is picklable @@ -496,21 +502,19 @@ def __getnewargs__(self): @dataclass(frozen=True, init=False) class _HypercubeFace(Hypercube, Face): - @overload - def __new__(cls, dim: Literal[1], **kwargs: Any) -> _SimplexFace: ... - - @overload - def __new__(cls, dim: int, **kwargs: Any) -> _HypercubeFace: ... - - def __new__(cls, dim: int, **kwargs: Any) -> Face: - if dim == 1: - return _SimplexFace(dim=1, **kwargs) - else: - return Shape.__new__(cls) - - def __init__(self, dim: int, **kwargs: Any) -> None: + def __init__(self, + dim: int, + volume_shape: Shape, + face_index: int, + volume_vertex_indices: tuple[int, ...], + map_to_volume: Callable[[ArrayF], ArrayF], + ) -> None: Hypercube.__init__(self, dim) - Face.__init__(self, **kwargs) + Face.__init__(self, + volume_shape=volume_shape, + face_index=face_index, + volume_vertex_indices=volume_vertex_indices, + map_to_volume=map_to_volume) def _hypercube_face_to_vol_map( diff --git a/modepy/spaces.py b/modepy/spaces.py index e328c54..c67c896 100644 --- a/modepy/spaces.py +++ b/modepy/spaces.py @@ -51,7 +51,7 @@ from abc import ABC, abstractmethod from functools import singledispatch from numbers import Number -from typing import Literal, TypeVar, overload +from typing import TypeVar import numpy as np from typing_extensions import override @@ -121,24 +121,24 @@ class TensorProductSpace(FunctionSpace): bases: tuple[FunctionSpace, ...] """A :class:`tuple` of the base spaces that take part in the tensor product.""" - @overload - # pyright-ignore: they overlap, can't be helped. - def __new__(cls, bases: tuple[FunctionSpaceT]) -> FunctionSpaceT: ... # pyright: ignore[reportOverlappingOverload] + def __init__(self, + bases: tuple[FunctionSpace, ...], *, + flatten: bool = True) -> None: + if flatten: + from warnings import warn - @overload - def __new__(cls, bases: tuple[FunctionSpace, ...]) -> TensorProductSpace: ... + if any(isinstance(s, TensorProductSpace) for s in bases): + warn(f"Automatic flattening in the '{type(self).__name__}' constructor " + "is deprecated and will be set to False in 2026. Use " + f"'{type(self).__name__}.flatten()' instead to manually flatten.", + DeprecationWarning, stacklevel=2) - def __new__(cls, bases: tuple[FunctionSpace, ...]) -> FunctionSpace: - if len(bases) == 1: - return bases[0] - else: - return FunctionSpace.__new__(cls) + bases = sum(( + space.bases if isinstance(space, TensorProductSpace) else (space,) + for space in bases + ), ()) - def __init__(self, bases: tuple[FunctionSpace, ...]) -> None: - self.bases = sum(( - space.bases if isinstance(space, TensorProductSpace) else (space,) - for space in bases - ), ()) + self.bases = bases def __getnewargs__(self): # Ensures TensorProductSpace is picklable @@ -172,6 +172,25 @@ def __repr__(self) -> str: f"bases={self.bases!r}" ")") + def flatten(self) -> FunctionSpace: + """Flattens a tensor product space into its component pieces. + + This function recursively removes tensor product spaces from + :attr:`TensorProductSpace.bases`. If only a single spaces remains in the + tensor product, then it is returned directly. + """ + bases: list[FunctionSpace] = [] + for s in self.bases: + if isinstance(s, TensorProductSpace): + s = s.flatten() + + bases.extend(s.bases if isinstance(s, TensorProductSpace) else [s]) + + if len(bases) == 1: + return bases[0] + else: + return TensorProductSpace(tuple(bases)) + @space_for_shape.register(TensorProductShape) def space_for_tensor_product_shape( @@ -256,21 +275,8 @@ class QN(TensorProductSpace): \left \{\prod_{i=1}^d x_i^{n_i}:\max n_i\le N\right\}. """ - @overload - # pyright-ignore: they overlap, can't be helped. - def __new__(cls, spatial_dim: Literal[1], order: int) -> PN: ... # pyright: ignore[reportOverlappingOverload] - - @overload - def __new__(cls, spatial_dim: int, order: int) -> QN: ... - - def __new__(cls, spatial_dim: int, order: int) -> FunctionSpace: - if spatial_dim == 1: - return PN(spatial_dim, order) - else: - return FunctionSpace.__new__(cls) - def __init__(self, spatial_dim: int, order: int) -> None: - super().__init__((PN(1, order),) * spatial_dim) + super().__init__((PN(1, order),) * spatial_dim, flatten=False) @property @override diff --git a/modepy/test/test_tools.py b/modepy/test/test_tools.py index 85470b1..71ecd88 100644 --- a/modepy/test/test_tools.py +++ b/modepy/test/test_tools.py @@ -466,9 +466,6 @@ def test_tensor_product_shapes(): (mp.TensorProductShape((mp.Simplex(1), mp.Simplex(2))), 3, 6, 5), ] - assert isinstance(mp.Hypercube(1), mp.Simplex) - assert isinstance(mp.TensorProductShape((mp.Simplex(2),)), mp.Simplex) - for shape, dim, nvertices, nfaces in shapes: assert shape.dim == dim assert shape.nvertices == nvertices From a27cd454f6c50a2ff938636b562d10bb552b7085 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Fri, 25 Jul 2025 15:25:33 +0300 Subject: [PATCH 2/2] chore: update baseline --- .basedpyright/baseline.json | 112 ------------------------------------ 1 file changed, 112 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index f4c0af4..05ac5d3 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -5331,118 +5331,6 @@ "lineCount": 1 } }, - { - "code": "reportOverlappingOverload", - "range": { - "startColumn": 8, - "endColumn": 15, - "lineCount": 1 - } - }, - { - "code": "reportOverlappingOverload", - "range": { - "startColumn": 8, - "endColumn": 15, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 40, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 33, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 33, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 41, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 41, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 41, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 41, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 35, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 30, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 30, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 30, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 30, - "endColumn": 36, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": {