112 changes: 0 additions & 112 deletions .basedpyright/baseline.json
@@ -5331,118 +5331,6 @@
"lineCount": 1
}
},
{
"code": "reportOverlappingOverload",
"range": {
"startColumn": 8,
"endColumn": 15,
"lineCount": 1
}
},
{
"code": "reportOverlappingOverload",
"range": {
"startColumn": 8,
"endColumn": 15,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 40,
"endColumn": 46,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 33,
"endColumn": 39,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 33,
"endColumn": 39,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 41,
"endColumn": 47,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 41,
"endColumn": 47,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 41,
"endColumn": 47,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 41,
"endColumn": 47,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 35,
"endColumn": 41,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 30,
"endColumn": 36,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 30,
"endColumn": 36,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 30,
"endColumn": 36,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 30,
"endColumn": 36,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
98 changes: 51 additions & 47 deletions modepy/shapes.py
@@ -221,7 +221,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial, singledispatch
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload
from typing import TYPE_CHECKING, TypeVar, cast

import numpy as np
from typing_extensions import override
@@ -343,25 +343,23 @@ class TensorProductShape(Shape):
bases: tuple[Shape, ...]
"""A :class:`tuple` of base shapes that form the tensor product."""

@overload
# pyright-ignore: they overlap, can't be helped.
def __new__(cls, bases: tuple[ShapeT]) -> ShapeT: ... # pyright: ignore[reportOverlappingOverload]

@overload
def __new__(cls, bases: tuple[Shape, ...]) -> TensorProductShape: ...
def __init__(self,
bases: tuple[Shape, ...],
flatten: bool = True) -> None:
# flatten input shapes
if flatten:
from warnings import warn

def __new__(cls, bases: tuple[Shape, ...]) -> Shape:
if len(bases) == 1:
return bases[0]
else:
return Shape.__new__(cls)
if any(isinstance(s, TensorProductShape) for s in bases):
warn(f"Automatic flattening in the '{type(self).__name__}' constructor "
"is deprecated and will be set to False in 2026. Use "
f"'{type(self).__name__}.flatten()' instead to manually flatten.",
DeprecationWarning, stacklevel=2)

def __init__(self, bases: tuple[Shape, ...]) -> None:
# flatten input shapes
bases = sum((
s.bases if isinstance(s, TensorProductShape) else (s,)
for s in bases
), ())
bases = sum((
s.bases if isinstance(s, TensorProductShape) else (s,)
for s in bases
), ())

nsegments = len([s for s in bases if s.dim == 1])
if nsegments < len(bases) - 1:
@@ -384,7 +382,27 @@ def nvertices(self) -> int:
def nfaces(self) -> int:
# FIXME: this obviously only works for `shape x segment x segment x ...`
*segments, shape = sorted(self.bases, key=lambda s: s.dim)
return shape.nfaces + len(segments) * segments[0].nfaces
nfaces = segments[0].nfaces if segments else 0
return shape.nfaces + len(segments) * nfaces

def flatten(self) -> Shape:
"""Flattens a tensor product shape into its component pieces.

This function recursively removes tensor product shapes from
:attr:`TensorProductShape.bases`. If only a single shape remains in the
tensor product, then it is returned directly.
"""
bases: list[Shape] = []
for s in self.bases:
if isinstance(s, TensorProductShape):
s = s.flatten()

bases.extend(s.bases if isinstance(s, TensorProductShape) else [s])

if len(bases) == 1:
return bases[0]
else:
return TensorProductShape(tuple(bases))


@unit_vertices_for_shape.register(TensorProductShape)
@@ -470,24 +488,12 @@ def faces_for_simplex(shape: Simplex) -> tuple[Face, ...]:

# {{{ hypercube

@dataclass(frozen=True)
@dataclass(frozen=True, init=False)
class Hypercube(TensorProductShape):
"""An n-dimensional hypercube (line, square, hexahedron, etc.)."""

@overload
def __new__(cls, dim: Literal[1]) -> Simplex: ...

@overload
def __new__(cls, dim: int) -> Hypercube: ...

def __new__(cls, dim: int) -> Shape:
if dim == 1:
return Simplex(1)
else:
return Shape.__new__(cls)

def __init__(self, dim: int) -> None:
super().__init__((Simplex(1),) * dim)
super().__init__((Simplex(1),) * dim, flatten=False)

def __getnewargs__(self):
# NOTE: ensures Hypercube is picklable
@@ -496,21 +502,19 @@ def __getnewargs__(self):

@dataclass(frozen=True, init=False)
class _HypercubeFace(Hypercube, Face):
@overload
def __new__(cls, dim: Literal[1], **kwargs: Any) -> _SimplexFace: ...

@overload
def __new__(cls, dim: int, **kwargs: Any) -> _HypercubeFace: ...

def __new__(cls, dim: int, **kwargs: Any) -> Face:
if dim == 1:
return _SimplexFace(dim=1, **kwargs)
else:
return Shape.__new__(cls)

def __init__(self, dim: int, **kwargs: Any) -> None:
def __init__(self,
dim: int,
volume_shape: Shape,
face_index: int,
volume_vertex_indices: tuple[int, ...],
map_to_volume: Callable[[ArrayF], ArrayF],
) -> None:
Hypercube.__init__(self, dim)
Face.__init__(self, **kwargs)
Face.__init__(self,
volume_shape=volume_shape,
face_index=face_index,
volume_vertex_indices=volume_vertex_indices,
map_to_volume=map_to_volume)


def _hypercube_face_to_vol_map(
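To summarize the shapes.py changes above, here is a minimal usage sketch of the reworked TensorProductShape constructor and the new flatten() method. It is inferred from this diff alone (the elided validation code is assumed not to interfere) and has not been run against the branch:

import modepy as mp

# Building from non-tensor-product bases is unchanged and warning-free.
prism = mp.TensorProductShape((mp.Simplex(2), mp.Simplex(1)))

# Nesting is no longer flattened silently; construct with flatten=False and
# flatten on demand. With a single remaining base, flatten() returns it
# directly, which is what the removed __new__ overloads used to do.
nested = mp.TensorProductShape((prism, mp.Simplex(1)), flatten=False)
flat = nested.flatten()
assert isinstance(flat, mp.TensorProductShape) and len(flat.bases) == 3

single = mp.TensorProductShape((mp.Simplex(2),), flatten=False).flatten()
assert isinstance(single, mp.Simplex)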
66 changes: 36 additions & 30 deletions modepy/spaces.py
@@ -51,7 +51,7 @@
from abc import ABC, abstractmethod
from functools import singledispatch
from numbers import Number
from typing import Literal, TypeVar, overload
from typing import TypeVar

import numpy as np
from typing_extensions import override
@@ -121,24 +121,24 @@ class TensorProductSpace(FunctionSpace):
bases: tuple[FunctionSpace, ...]
"""A :class:`tuple` of the base spaces that take part in the tensor product."""

@overload
# pyright-ignore: they overlap, can't be helped.
def __new__(cls, bases: tuple[FunctionSpaceT]) -> FunctionSpaceT: ... # pyright: ignore[reportOverlappingOverload]
def __init__(self,
bases: tuple[FunctionSpace, ...], *,
flatten: bool = True) -> None:
if flatten:
from warnings import warn

@overload
def __new__(cls, bases: tuple[FunctionSpace, ...]) -> TensorProductSpace: ...
if any(isinstance(s, TensorProductSpace) for s in bases):
warn(f"Automatic flattening in the '{type(self).__name__}' constructor "
"is deprecated and will be set to False in 2026. Use "
f"'{type(self).__name__}.flatten()' instead to manually flatten.",
DeprecationWarning, stacklevel=2)

def __new__(cls, bases: tuple[FunctionSpace, ...]) -> FunctionSpace:
if len(bases) == 1:
return bases[0]
else:
return FunctionSpace.__new__(cls)
bases = sum((
space.bases if isinstance(space, TensorProductSpace) else (space,)
for space in bases
), ())

def __init__(self, bases: tuple[FunctionSpace, ...]) -> None:
self.bases = sum((
space.bases if isinstance(space, TensorProductSpace) else (space,)
for space in bases
), ())
self.bases = bases

def __getnewargs__(self):
# Ensures TensorProductSpace is picklable
@@ -172,6 +172,25 @@ def __repr__(self) -> str:
f"bases={self.bases!r}"
")")

def flatten(self) -> FunctionSpace:
"""Flattens a tensor product space into its component pieces.

This function recursively removes tensor product spaces from
:attr:`TensorProductSpace.bases`. If only a single space remains in the
tensor product, then it is returned directly.
"""
bases: list[FunctionSpace] = []
for s in self.bases:
if isinstance(s, TensorProductSpace):
s = s.flatten()

bases.extend(s.bases if isinstance(s, TensorProductSpace) else [s])

if len(bases) == 1:
return bases[0]
else:
return TensorProductSpace(tuple(bases))


@space_for_shape.register(TensorProductShape)
def space_for_tensor_product_shape(
@@ -256,21 +275,8 @@ class QN(TensorProductSpace):
\left \{\prod_{i=1}^d x_i^{n_i}:\max n_i\le N\right\}.
"""

@overload
# pyright-ignore: they overlap, can't be helped.
def __new__(cls, spatial_dim: Literal[1], order: int) -> PN: ... # pyright: ignore[reportOverlappingOverload]

@overload
def __new__(cls, spatial_dim: int, order: int) -> QN: ...

def __new__(cls, spatial_dim: int, order: int) -> FunctionSpace:
if spatial_dim == 1:
return PN(spatial_dim, order)
else:
return FunctionSpace.__new__(cls)

def __init__(self, spatial_dim: int, order: int) -> None:
super().__init__((PN(1, order),) * spatial_dim)
super().__init__((PN(1, order),) * spatial_dim, flatten=False)

@property
@override
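The spaces.py changes mirror this for TensorProductSpace and QN. A hedged sketch, again inferred from the diff; the warnings.catch_warnings block is only there to illustrate when the new DeprecationWarning fires:

import warnings
import modepy as mp

# QN(1, order) no longer collapses to PN via __new__; it is a genuine
# single-factor tensor product space now.
q1 = mp.QN(1, 3)
assert isinstance(q1, mp.TensorProductSpace)

# Passing nested tensor products with the default flatten=True still works,
# but emits the DeprecationWarning quoted in the CI annotations below.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mp.TensorProductSpace((mp.QN(2, 3), mp.QN(1, 2)))
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# The forward-compatible spelling: build unflattened, then flatten explicitly.
nested = mp.TensorProductSpace((mp.QN(2, 3), mp.QN(1, 2)), flatten=False)
flat = nested.flatten()
assert len(flat.bases) == 3  # three PN(1, .) factors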
3 changes: 0 additions & 3 deletions modepy/test/test_tools.py
@@ -176,10 +176,10 @@


@pytest.mark.parametrize("space_nh", [
mp.TensorProductSpace((mp.QN(1, 3),)),

[GitHub Actions annotation (Pytest on Py3.10 / Py3.x), repeated for lines 179–182: Automatic flattening in the 'TensorProductSpace' constructor is deprecated and will be set to False in 2026. Use 'TensorProductSpace.flatten()' instead to manually flatten.]
mp.TensorProductSpace((mp.QN(1, 3), mp.QN(1, 5))),

mp.TensorProductSpace((mp.QN(1, 3), mp.QN(1, 5), mp.QN(1, 2))),

mp.TensorProductSpace((mp.QN(2, 3), mp.QN(1, 2))),

])
def test_non_homogeneous_tensor_product_resampling(space_nh):
shape = mp.Hypercube(space_nh.spatial_dim)
Expand Down Expand Up @@ -466,9 +466,6 @@
(mp.TensorProductShape((mp.Simplex(1), mp.Simplex(2))), 3, 6, 5),
]

assert isinstance(mp.Hypercube(1), mp.Simplex)
assert isinstance(mp.TensorProductShape((mp.Simplex(2),)), mp.Simplex)

for shape, dim, nvertices, nfaces in shapes:
assert shape.dim == dim
assert shape.nvertices == nvertices
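The two asserts removed above capture the behavioral change: Hypercube(1) and single-base tensor products no longer collapse to other types via __new__. A sketch of what should hold instead under this diff (the numeric values for the 1D case are assumed from the usual segment shape and have not been run against the branch):

import modepy as mp

# Previously Hypercube(1) collapsed to Simplex(1) via __new__ (hence the
# removed asserts); with this diff it stays a Hypercube.
hc = mp.Hypercube(1)
assert isinstance(hc, mp.Hypercube)
assert not isinstance(hc, mp.Simplex)
assert isinstance(mp.TensorProductShape((mp.Simplex(2),)), mp.TensorProductShape)

# The amended nfaces property matters exactly for this case: with a single
# 1D base there are no extra "segment" factors, so the old
# `segments[0].nfaces` lookup would have raised an IndexError.
assert (hc.dim, hc.nvertices, hc.nfaces) == (1, 2, 2)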