Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37,180 changes: 37,180 additions & 0 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,20 @@ jobs:
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh
. ./prepare-and-run-pylint.sh "$(basename $GITHUB_REPOSITORY)" examples/*.py test/test_*.py

mypy:
name: Mypy
basedpyright:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: "Main Script"
run: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
# pyright is happier with missing packages than installed, untyped ones
sed -i /oct2py/d .test-conda-env-py3.yml
sed -i /h5py/d .test-conda-env-py3.yml
build_py_project_in_conda_env
python -m pip install mypy
./run-mypy.sh
python -m pip install basedpyright
basedpyright

pytest3:
name: Pytest Conda Py3
Expand Down
13 changes: 0 additions & 13 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,6 @@ Pylint:
except:
- tags

Mypy:
script: |
EXTRA_INSTALL="Cython mpi4py"
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_venv
python -m pip install mypy
./run-mypy.sh
tags:
- python3
except:
- tags

Downstream:
parallel:
matrix:
Expand Down
33 changes: 33 additions & 0 deletions doc/misc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,36 @@ AK also gratefully acknowledges a hardware gift from Nvidia Corporation.

The views and opinions expressed herein do not necessarily reflect those of the
funding agencies.

References
----------

.. class:: ArrayContext

See :class:`arraycontext.ArrayContext`.

.. class:: _Mesh

See :class:`meshmode.mesh.Mesh`.

.. class:: _MeshElementGroup

See :class:`meshmode.mesh.MeshElementGroup`.

.. class:: _DOFArray

See :class:`meshmode.dof_array.DOFArray`.

.. class:: DTypeLike

See :data:`numpy.typing.DTypeLike`.

.. currentmodule:: np

.. class:: dtype

See :class:`numpy.dtype`.

.. class:: complexfloating

See :class:`numpy.complexfloating`.
3 changes: 3 additions & 0 deletions meshmode/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


__copyright__ = "Copyright (C) 2014 Andreas Kloeckner"

__license__ = """
Expand Down
26 changes: 21 additions & 5 deletions meshmode/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
.. autoclass:: PyOpenCLArrayContext
.. autoclass:: PytatoPyOpenCLArrayContext
"""
from __future__ import annotations


__copyright__ = "Copyright (C) 2020 Andreas Kloeckner"

Expand All @@ -25,8 +27,11 @@
THE SOFTWARE.
"""

from typing import TYPE_CHECKING
from warnings import warn

from typing_extensions import override

from arraycontext import (
PyOpenCLArrayContext as PyOpenCLArrayContextBase,
PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase,
Expand All @@ -38,6 +43,11 @@
)


if TYPE_CHECKING:
import pytato as pt_typ
from loopy import TranslationUnit


def thaw(actx, ary):
warn("meshmode.array_context.thaw is deprecated. Use arraycontext.thaw instead. "
"WARNING: The argument order is reversed between these two functions. "
Expand All @@ -49,7 +59,7 @@ def thaw(actx, ary):

# {{{ kernel transform function

def _transform_loopy_inner(t_unit):
def _transform_loopy_inner(t_unit: TranslationUnit):
import loopy as lp
from arraycontext.transform_metadata import ElementwiseMapKernelTag
from pymbolic.primitives import Subscript, Variable
Expand Down Expand Up @@ -136,6 +146,7 @@ def _transform_loopy_inner(t_unit):
"FirstAxisIsElementsTag-tagged kernels must be "
"subscripts")

assert isinstance(assignee.aggregate, Variable)
if assignee.aggregate.name not in first_axis_el_args:
continue

Expand Down Expand Up @@ -207,7 +218,8 @@ class PyOpenCLArrayContext(PyOpenCLArrayContextBase):
See :mod:`meshmode.transform_metadata` for relevant metadata.
"""

def transform_loopy_program(self, t_unit):
@override
def transform_loopy_program(self, t_unit: TranslationUnit):
default_ep = t_unit.default_entrypoint
options = default_ep.options
if not (options.return_dict and options.no_numpy):
Expand Down Expand Up @@ -238,15 +250,18 @@ def transform_loopy_program(self, t_unit):
# {{{ pytato pyopencl array context subclass

class PytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContextBase):
def transform_dag(self, dag):
@override
def transform_dag(self, dag: pt_typ.DictOfNamedArrays) -> pt_typ.DictOfNamedArrays:
dag = super().transform_dag(dag)

# {{{ /!\ Remove tags from NamedArrays
# See <https://www.github.com/inducer/pytato/issues/195>

import pytato as pt

def untag_loopy_call_results(expr):
def untag_loopy_call_results(
expr: pt.Array | pt.AbstractResultWithNamedArrays
) -> pt.Array | pt.AbstractResultWithNamedArrays:
if isinstance(expr, pt.NamedArray):
return expr.copy(tags=frozenset(),
axes=(pt.Axis(frozenset()),)*expr.ndim)
Expand All @@ -259,7 +274,8 @@ def untag_loopy_call_results(expr):

return dag

def transform_loopy_program(self, t_unit):
@override
def transform_loopy_program(self, t_unit: TranslationUnit):
# FIXME: Do not parallelize for now.
return t_unit

Expand Down
Loading
Loading