498 changes: 22 additions & 476 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions doc/conf.py
@@ -88,6 +88,7 @@
# meshmode
"Discretization": "class:meshmode.discretization.Discretization",
"DOFArray": "class:meshmode.dof_array.DOFArray",
"ElementGroupFactory": "class:meshmode.discretization.ElementGroupFactory",
# boxtree
"FromSepSmallerCrit": "obj:boxtree.traversal.FromSepSmallerCrit",
"TimingResult": "class:boxtree.timing.TimingResult",
21 changes: 10 additions & 11 deletions pyproject.toml
@@ -83,21 +83,20 @@ extend-select = [
"NPY", # numpy
"Q", # flake8-quotes
"RUF", # ruff
"TC", # flake8-type-checking
"UP", # pyupgrade
"W", # pycodestyle
"TC",
]
extend-ignore = [
"C90", # McCabe complexity
"E226", # missing whitespace around arithmetic operator
"E402", # module level import not at the top of file
"N802", # function name should be lowercase
"N803", # argument name should be lowercase
"N806", # variable name should be lowercase
"RUF012", # ClassVar for mutable class attributes
"RUF022", # __all__ is not sorted
"UP031", # use f-strings instead of %
"UP032", # use f-strings instead of .format
"C90", # McCabe complexity
"E226", # missing whitespace around arithmetic operator
"E402", # module level import not at the top of file
"N802", # function name should be lowercase
"N803", # argument name should be lowercase
"N806", # variable name should be lowercase
"RUF067", # non-empty-init-module
"UP031", # use f-strings instead of %
"UP032", # use f-strings instead of .format
]
exclude = [
"experiments/*.py",
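
The newly listed "TC" (flake8-type-checking) rules enforce the pattern used elsewhere in this PR, for example in pytential/array_context.py and pytential/qbx/refinement.py: imports needed only for annotations go under an if TYPE_CHECKING: guard so they cost nothing at runtime. A minimal sketch of that pattern, where expensive_module and Thing are hypothetical names:

# Minimal sketch of the import pattern the TC rules encourage;
# expensive_module and Thing are hypothetical, not part of pytential.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported only while type checking, never at runtime
    from expensive_module import Thing

def frob(t: "Thing") -> None:   # string annotation avoids a runtime lookup
    print(t)
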
2 changes: 1 addition & 1 deletion pytential/__init__.py
@@ -141,4 +141,4 @@ def norm(
raise ValueError(f"unsupported norm order: {p}")


__all__ = ["sym", "bind", "GeometryCollection"]
__all__ = ("GeometryCollection", "bind", "integral", "norm", "sym")
18 changes: 14 additions & 4 deletions pytential/array_context.py
@@ -23,6 +23,10 @@
THE SOFTWARE.
"""

from typing import TYPE_CHECKING, ClassVar

from typing_extensions import override

from arraycontext.pytest import (
_PytestPyOpenCLArrayContextFactoryWithClass,
register_pytest_array_context_factory,
@@ -33,6 +37,10 @@
)


if TYPE_CHECKING:
import loopy as lp
from arraycontext import ArrayContext

__doc__ = """
.. autoclass:: PyOpenCLArrayContext
"""
@@ -41,7 +49,8 @@
# {{{ PyOpenCLArrayContext

class PyOpenCLArrayContext(PyOpenCLArrayContextBase):
def transform_loopy_program(self, t_unit):
@override
def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
kernel = t_unit.default_entrypoint
options = kernel.options

@@ -68,7 +77,7 @@ def transform_loopy_program(self, t_unit):

# {{{ pytest

def _acf():
def _acf() -> PyOpenCLArrayContext: # pyright: ignore[reportUnusedFunction]
import pyopencl as cl
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
@@ -78,9 +87,10 @@ def _acf():

class PytestPyOpenCLArrayContextFactory(
_PytestPyOpenCLArrayContextFactoryWithClass):
actx_class = PyOpenCLArrayContext
actx_class: ClassVar[ArrayContext] = PyOpenCLArrayContext

def __call__(self):
@override
def __call__(self) -> ArrayContext:
# NOTE: prevent any cache explosions during testing!
from sympy.core.cache import clear_cache
clear_cache()
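
The @override decorator added above comes from typing_extensions (PEP 698; it is part of typing itself from Python 3.12). It does not change behavior at runtime: a type checker verifies that the decorated method really overrides something in a base class, which catches renamed or misspelled methods. A minimal sketch with hypothetical classes, not part of pytential:

# Minimal sketch of the @override decorator; Base and Derived are hypothetical.
from typing_extensions import override

class Base:
    def transform(self, x: int) -> int:
        return x

class Derived(Base):
    @override
    def transform(self, x: int) -> int:  # ok: Base.transform exists
        return 2 * x

    # an @override-decorated method named "transfrom" (typo) would be flagged
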
6 changes: 4 additions & 2 deletions pytential/linalg/__init__.py
@@ -33,7 +33,9 @@


__all__ = (
"IndexList", "TargetAndSourceClusterList",
"make_index_list", "make_index_cluster_cartesian_product",
"IndexList",
"TargetAndSourceClusterList",
"interp_decomp",
"make_index_cluster_cartesian_product",
"make_index_list",
)
14 changes: 7 additions & 7 deletions pytential/qbx/__init__.py
@@ -1067,12 +1067,12 @@ def get_flat_strengths_from_densities(
# }}}


__all__ = [
"QBXLayerPotentialSource",
"QBXTargetAssociationFailedError",
"QBXDefaultExpansionFactory",
"ExpansionFactoryBase",
"LocalExpansionBase",
]
__all__ = (
"ExpansionFactoryBase",
"LocalExpansionBase",
"QBXDefaultExpansionFactory",
"QBXLayerPotentialSource",
"QBXTargetAssociationFailedError",
)

# vim: fdm=marker
4 changes: 2 additions & 2 deletions pytential/qbx/cost.py
@@ -30,7 +30,7 @@
import logging
from abc import abstractmethod
from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar

import numpy as np
from mako.template import Template
@@ -515,7 +515,7 @@ def get_unit_calibration_params():

return calibration_params

_QBX_STAGE_TO_CALIBRATION_PARAMETER = {
_QBX_STAGE_TO_CALIBRATION_PARAMETER: ClassVar[dict[str, str]] = {
"form_global_qbx_locals": "c_p2qbxl",
"translate_box_multipoles_to_qbx_local": "c_m2qbxl",
"translate_box_local_to_qbx_local": "c_l2qbxl",
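
Annotating the class-level mapping with ClassVar is what ruff's RUF012 asks for: it tells the type checker the dict belongs to the class and is shared, rather than being an instance attribute. A minimal sketch of the same idiom, with Translator as a hypothetical class:

# Minimal sketch of the ClassVar annotation used above; Translator is hypothetical.
from typing import ClassVar

class Translator:
    # shared, class-level table, not an instance field
    _STAGE_TO_PARAM: ClassVar[dict[str, str]] = {
        "form_global_qbx_locals": "c_p2qbxl",
    }

    def param_for(self, stage: str) -> str:
        return self._STAGE_TO_PARAM[stage]
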
42 changes: 29 additions & 13 deletions pytential/qbx/refinement.py
@@ -27,7 +27,7 @@
"""

import logging
from typing import cast
from typing import TYPE_CHECKING, Any, cast

import numpy as np

@@ -42,11 +42,18 @@
from pytential.qbx.utils import (
QBX_TREE_C_PREAMBLE,
QBX_TREE_MAKO_DEFS,
TreeCodeContainer,
TreeCodeContainerMixin,
TreeWranglerBase,
)


if TYPE_CHECKING:
from meshmode.discretization import ElementGroupFactory

from pytential.collection import GeometryCollection
from pytential.symbolic.dof_desc import DiscretizationStage

logger = logging.getLogger(__name__)


@@ -224,17 +231,24 @@
# {{{ code container

class RefinerCodeContainer(TreeCodeContainerMixin):
array_context: PyOpenCLArrayContext
tree_code_container: TreeCodeContainer

def __init__(self, actx: PyOpenCLArrayContext):
def __init__(self, actx: PyOpenCLArrayContext) -> None:
self.array_context = actx

from pytential.qbx.utils import tree_code_container
self.tree_code_container = tree_code_container(actx)

@memoize_method
def expansion_disk_undisturbed_by_sources_checker(
self, dimensions, coord_dtype, box_id_dtype, peer_list_idx_dtype,
particle_id_dtype, max_levels):
self,
dimensions: int,
coord_dtype: np.dtype[np.floating[Any]],
box_id_dtype: np.dtype[np.integer[Any]],
peer_list_idx_dtype: np.dtype[np.integer[Any]],
particle_id_dtype: np.dtype[np.integer[Any]],
max_levels: int) -> int:
return EXPANSION_DISK_UNDISTURBED_BY_SOURCES_CHECKER.generate(
self.array_context.context,
dimensions, coord_dtype, box_id_dtype, peer_list_idx_dtype,
@@ -935,15 +949,17 @@ def get_from_cache(from_ds, to_ds):

# {{{ refine_geometry_collection

def refine_geometry_collection(places,
group_factory=None,
refine_discr_stage=None,
kernel_length_scale=None,
force_stage2_uniform_refinement_rounds=None,
scaled_max_curvature_threshold=None,
expansion_disturbance_tolerance=None,
maxiter=None,
debug=None, visualize=False):
def refine_geometry_collection(
places: GeometryCollection,
group_factory: ElementGroupFactory | None = None,
refine_discr_stage: DiscretizationStage | None = None,
kernel_length_scale: float | np.floating[Any] | None = None,
force_stage2_uniform_refinement_rounds: int | None = None,
scaled_max_curvature_threshold: float | None = None,
expansion_disturbance_tolerance: float | None = None,
maxiter: int | None = None,
debug: bool | None = None,
visualize: bool = False) -> GeometryCollection:
"""Entry point for refining all the
:class:`~pytential.qbx.QBXLayerPotentialSource` in the given collection.
The :class:`~pytential.collection.GeometryCollection` performs
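
The new signature of expansion_disk_undisturbed_by_sources_checker spells its dtype arguments as np.dtype[np.floating[Any]] and np.dtype[np.integer[Any]]; numpy's dtype is generic over its scalar type, so a checker can reject, say, an integer dtype passed where a floating one is expected. A minimal sketch, with make_coords as a hypothetical helper that is not part of pytential:

# Minimal sketch of parametrized np.dtype annotations; make_coords is hypothetical.
from typing import Any
import numpy as np

def make_coords(n: int, coord_dtype: np.dtype[np.floating[Any]]) -> np.ndarray:
    # np.zeros accepts a dtype instance directly
    return np.zeros(n, dtype=coord_dtype)

coords = make_coords(8, np.dtype(np.float64))    # accepted
# make_coords(8, np.dtype(np.int32))             # a type checker flags this
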
2 changes: 1 addition & 1 deletion pytential/symbolic/primitives.py
@@ -370,7 +370,7 @@
.. autofunction:: pretty
"""

__all__ = (
__all__ = ( # noqa: RUF022
# re-export from pymbolic
"Variable",
"cse",
62 changes: 46 additions & 16 deletions test/extra_curve_data.py
@@ -23,22 +23,30 @@
THE SOFTWARE.
"""

from typing import Any, TypeAlias

import numpy as np
import numpy.linalg as la
from typing_extensions import override


Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[np.floating]]
Array2D: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.floating]]


class Curve:
def plot(self, npoints=50):
def plot(self, npoints: int = 50) -> None:
import matplotlib.pyplot as plt
x, y = self(np.linspace(0, 1, npoints))

plt.plot(x, y, marker=".", lw=0)
plt.axis("equal")
plt.show()

def __add__(self, other):
def __add__(self, other: Curve) -> Curve:
return CompositeCurve(self, other)

def __call__(self, ts):
def __call__(self, ts: Array1D) -> Array2D:
raise NotImplementedError


@@ -47,33 +55,41 @@ class CompositeCurve(Curve):
Parametrization of two or more curves combined.
"""

def __init__(self, *objs):
curves = []
curves: tuple[Curve, ...]

def __init__(self, *objs: Curve) -> None:
curves: list[Curve] = []
for obj in objs:
if isinstance(obj, CompositeCurve):
curves.extend(obj.curves)
else:
curves.append(obj)
self.curves = curves

def __call__(self, ts):
self.curves = tuple(curves)

@override
def __call__(self, ts: Array1D) -> Array2D:
from itertools import pairwise

ranges = np.linspace(0, 1, len(self.curves) + 1)
ts_argsort = np.argsort(ts)
ts_sorted = ts[ts_argsort]
ts_split_points = np.searchsorted(ts_sorted, ranges)

# Make sure the last entry = len(ts), otherwise if ts finishes with a
# trail of 1s, then they won't be forwarded to the last curve.
ts_split_points[-1] = len(ts)
result = []

result: list[Array2D] = []
subranges = [slice(*pair) for pair in pairwise(ts_split_points)]
for curve, subrange, (start, end) in zip(
self.curves, subranges, pairwise(ranges), strict=True):
ts_mapped = (ts_sorted[subrange] - start) / (end - start)
result.append(curve(ts_mapped))

final = np.concatenate(result, axis=-1)
assert len(final[0]) == len(ts)

return final


@@ -82,11 +98,14 @@ class Segment(Curve):
Represents a line segment.
"""

def __init__(self, start, end):
self.start = np.array(start)
self.end = np.array(end)
def __init__(self,
start: tuple[float, float],
end: tuple[float, float]) -> None:
self.start: Array1D = np.array(start)
self.end: Array1D = np.array(end)

def __call__(self, ts):
@override
def __call__(self, ts: Array1D) -> Array2D:
return (
self.start[:, np.newaxis]
+ ts * (self.end - self.start)[:, np.newaxis])
@@ -97,7 +116,16 @@ class Arc(Curve):
Represents an arc of a circle.
"""

def __init__(self, start, mid, end):
r: np.floating[Any]
center: Array1D

theta_range: Array1D
theta_increasing: bool

def __init__(self,
start: tuple[float, float],
mid: tuple[float, float],
end: tuple[float, float]) -> None:
"""
:arg start: starting point of the arc
:arg mid: any point along the arc
@@ -135,15 +163,17 @@ def __init__(self, start, mid, end):
self.theta_range = np.array(sorted([theta_start, theta_end]))
self.theta_increasing = theta_start <= theta_end

def __call__(self, t):
@override
def __call__(self, ts: Array1D) -> Array2D:
if self.theta_increasing:
thetas = (
self.theta_range[0]
+ t * (self.theta_range[1] - self.theta_range[0]))
+ ts * (self.theta_range[1] - self.theta_range[0]))
else:
thetas = (
self.theta_range[1]
- t * (self.theta_range[1] - self.theta_range[0]))
- ts * (self.theta_range[1] - self.theta_range[0]))

val = (self.r * np.exp(1j * thetas)) + self.center
return np.array([val.real, val.imag])

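
The rewritten CompositeCurve.__call__ routes each parameter value to its sub-curve by sorting ts, splitting the sorted values against the evenly spaced ranges with np.searchsorted, and rescaling every chunk to [0, 1]. The following standalone sketch, with made-up sample values, shows just that splitting step:

# Standalone sketch of the searchsorted-based splitting in CompositeCurve.__call__,
# using made-up sample values.
from itertools import pairwise
import numpy as np

ts = np.array([0.9, 0.1, 0.5, 1.0, 0.3])
ncurves = 2
ranges = np.linspace(0, 1, ncurves + 1)          # [0.0, 0.5, 1.0]

ts_sorted = np.sort(ts)
split_points = np.searchsorted(ts_sorted, ranges)
split_points[-1] = len(ts)                       # keep trailing 1s in the last curve

for (start, end), (lo, hi) in zip(pairwise(ranges), pairwise(split_points), strict=True):
    # rescale each chunk to [0, 1] for its sub-curve
    print((ts_sorted[lo:hi] - start) / (end - start))
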