diff --git a/.gitignore b/.gitignore index d35386ed39..d6289092fe 100644 --- a/.gitignore +++ b/.gitignore @@ -107,6 +107,7 @@ venv.bak/ /site # DaCe +.dacecache *.sdfg .dace.conf diff --git a/setup.cfg b/setup.cfg index 0b84c6ce00..37f41ecd0a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -183,7 +183,7 @@ lines_after_imports = 2 default_section = THIRDPARTY sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER known_first_party = eve,gtc,gt4py,__externals__,__gtscript__ -known_third_party = attr,black,boltons,cached_property,click,dace,dawn4py,devtools,factory,hypothesis,jinja2,mako,networkx,numpy,packaging,pkg_resources,pybind11,pydantic,pytest,pytest_factoryboy,setuptools,tabulate,tests,typing_extensions,typing_inspect,xxhash +known_third_party = atlas4py,attr,black,boltons,cached_property,click,dace,dawn4py,devtools,factory,hypothesis,jinja2,mako,networkx,numpy,packaging,pkg_resources,pybind11,pydantic,pytest,pytest_factoryboy,setuptools,tabulate,tests,typing_extensions,typing_inspect,xxhash #-- mypy -- [mypy] diff --git a/src/iterator/ARCHITECTURE.md b/src/iterator/ARCHITECTURE.md new file mode 100644 index 0000000000..28daa3728e --- /dev/null +++ b/src/iterator/ARCHITECTURE.md @@ -0,0 +1,67 @@ +# Architecture + +Implements the iterator view as described [here](https://github.com/GridTools/concepts/wiki/Iterator-View). + +## Iterator view program in Python + +A program for the iterator view consists of Python functions decorated with `@fundef` and an entry point, *fencil*, which is a Python function decorated with `@fendef`. The *fencil* must only contain calls to the `closure(...)` function. + +Legal functions must not have side-effects; however, e.g., for debugging purposes, side-effects can be used in embedded execution. + +There are 2 modes of execution: *embedded* (direct execution in Python) and *tracing* (trace function calls -> Eve IR representation -> code generation). 
+The implementations of *embedded* and *tracing* are decoupled by registering themselves (dependency inversion) in the functions defined in `builtins.py` (contains dispatch functions for all builtins of the model) and `runtime.py` (contains dispatch mechanism for the `fendef` and `fundef` decorators and the `closure(...)` function). + +The builtins dispatcher is implemented in `dispatcher.py`. Implementations are registered with a key (`str`) (currently `tracing` and `embedded`). The active implementation is selected by pushing a key to the dispatcher stack. + +`fundef` returns a wrapper around the function, which dispatches `__call__` to a hook if a predicate is met (used for *tracing*). By default the original function is called (used in *embedded* mode). + +`fendef` returns a wrapper that dispatches to a registered function. Should be simplified. *Embedded* registers itself as default, *tracing* registers itself such that it's used if the fencil is called with a `backend` keyword argument. + +## Embedded execution + +Embedded execution is implemented in the file `embedded.py`. + +Sketch: +- fields are np arrays with named axes; names are instances of `CartesianAxis` +- in `closure()`, the stencil is executed for each point in the domain, the fields are wrapped in an iterator pointing to the current point of execution. +- `shift()` is lazy (to allow lift implementation), offsets are accumulated in the iterator and only executed when `deref()` is called. 
+- as described in the design, offsets are abstract; on fencil execution the `offset_provider` keyword argument needs to be specified, which is a dict of `str` to either `CartesianAxis` or `NeighborTableOffsetProvider` +- if the `column_axis` keyword argument is specified on fencil execution (or in the fencil decorator), all operations will be done column-wise in the given axis; `column_axis` needs to be specified if `scan` is used + +## Tracing + +An iterator view program is traced (implemented in `tracing.py`) and represented in a tree structure defined by the nodes in (`ir.py`). + +Sketch: +- Each builtin returns a `FunctionCall` node representing the builtin. +- For each `fundef`, the signature of the wrapped function is extracted, then it is invoked with `Sym` nodes as arguments. +- Expressions involving an `Expr` node (e.g. `Sym`) are converted to appropriate builtin calls, e.g. `4. + Sym(id='foo')` is converted to `FunCall(fun=SymRef(id='plus'), args=...)` +- In appropriate places values are converted to nodes, see `make_node()`. +- Finally the IR tree will be passed to `execute_program()` in `backend_executor.py` which will generate code for the program (and execute, if appropriate). + +## Backends + +See directory `backends/`. + +### Cpptoy + +Generates C++ code in the spirit of https://github.com/GridTools/gridtools/pull/1643. Incomplete, will be adapted to the full C++ prototype. (only code generation) + +### Lisp + +Incomplete. Example for the grammar used in the model design document. (not executable) + +### Embedded + +Generates from the IR an equivalent Python iterator view program which is then executed in embedded mode (round trip). + +### Double roundtrip + +Generates the Python iterator view program, traces it again, generates again and executes. Ensures that the generated Python code can still be traced. Note that the original program might differ from the generated program (e.g. `+` will be converted to the `plus()` builtin). 
The programs from the embedded and double roundtrip backends should be identical. + +## Adding a new builtin + +Currently there are 4 places where a new builtin needs to be added: +- `builtins.py`: for dispatching to an actual implementation +- `embedded.py` and `tracing.py`: for the respective implementation +- `ir.py`: we check for consistent use of symbols, therefore if a `FunCall` to the new builtin is used, it needs to be available in the symbol table. diff --git a/src/iterator/README.md b/src/iterator/README.md new file mode 100644 index 0000000000..7e59600739 --- /dev/null +++ b/src/iterator/README.md @@ -0,0 +1 @@ +# README diff --git a/src/iterator/__init__.py b/src/iterator/__init__.py new file mode 100644 index 0000000000..1f67e0c238 --- /dev/null +++ b/src/iterator/__init__.py @@ -0,0 +1,19 @@ +from typing import Optional, Union + +from . import builtins, runtime, tracing + + +__all__ = ["builtins", "runtime", "tracing"] + +from packaging.version import LegacyVersion, Version, parse +from pkg_resources import DistributionNotFound, get_distribution + + +try: + __version__: str = get_distribution("gt4py").version +except DistributionNotFound: + __version__ = "X.X.X.unknown" + +__versioninfo__: Optional[Union[LegacyVersion, Version]] = parse(__version__) + +del DistributionNotFound, LegacyVersion, Version, get_distribution, parse diff --git a/src/iterator/atlas_utils.py b/src/iterator/atlas_utils.py new file mode 100644 index 0000000000..153abb6fa6 --- /dev/null +++ b/src/iterator/atlas_utils.py @@ -0,0 +1,33 @@ +# GT4Py New Semantic Model - GridTools Framework +# +# Copyright (c) 2014-2021, ETH Zurich All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. GT4Py +# New Semantic Model is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or any later version. 
+# See the LICENSE.txt file at the top-level directory of this distribution for +# a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from atlas4py import IrregularConnectivity + + +class AtlasTable: + def __init__(self, atlas_connectivity) -> None: + self.atlas_connectivity = atlas_connectivity + + def __getitem__(self, indices): + primary_index = indices[0] + neigh_index = indices[1] + if isinstance(self.atlas_connectivity, IrregularConnectivity): + if neigh_index < self.atlas_connectivity.cols(primary_index): + return self.atlas_connectivity[primary_index, neigh_index] + else: + return None + else: + if neigh_index < 2: + return self.atlas_connectivity[primary_index, neigh_index] + else: + raise AssertionError() diff --git a/src/iterator/backend_executor.py b/src/iterator/backend_executor.py new file mode 100644 index 0000000000..0f66174c8b --- /dev/null +++ b/src/iterator/backend_executor.py @@ -0,0 +1,21 @@ +from devtools import debug + +from iterator.backends import backend +from iterator.ir import Program + + +def execute_program(prog: Program, *args, **kwargs): + assert "backend" in kwargs + assert len(prog.fencil_definitions) == 1 + + if "debug" in kwargs and kwargs["debug"]: + debug(prog) + + if not len(args) == len(prog.fencil_definitions[0].params): + raise RuntimeError("Incorrect number of arguments") + + if kwargs["backend"] in backend._BACKENDS: + b = backend.get_backend(kwargs["backend"]) + b(prog, *args, **kwargs) + else: + raise RuntimeError(f"Backend {kwargs['backend']} is not registered.") diff --git a/src/iterator/backends/__init__.py b/src/iterator/backends/__init__.py new file mode 100644 index 0000000000..b926168790 --- /dev/null +++ b/src/iterator/backends/__init__.py @@ -0,0 +1 @@ +from . 
import cpptoy, double_roundtrip, embedded, lisp diff --git a/src/iterator/backends/backend.py b/src/iterator/backends/backend.py new file mode 100644 index 0000000000..5c6fdf3f71 --- /dev/null +++ b/src/iterator/backends/backend.py @@ -0,0 +1,9 @@ +_BACKENDS = {} + + +def register_backend(name, backend): + _BACKENDS[name] = backend + + +def get_backend(name): + return _BACKENDS[name] diff --git a/src/iterator/backends/cpptoy.py b/src/iterator/backends/cpptoy.py new file mode 100644 index 0000000000..cb7e6bbd05 --- /dev/null +++ b/src/iterator/backends/cpptoy.py @@ -0,0 +1,59 @@ +from typing import Any + +from eve import codegen +from eve.codegen import FormatTemplate as as_fmt +from eve.codegen import MakoTemplate as as_mako +from iterator.backends import backend +from iterator.ir import OffsetLiteral +from iterator.transforms import apply_common_transforms + + +class ToyCpp(codegen.TemplatedGenerator): + Sym = as_fmt("{id}") + SymRef = as_fmt("{id}") + IntLiteral = as_fmt("{value}") + FloatLiteral = as_fmt("{value}") + AxisLiteral = as_fmt("{value}") + + def visit_OffsetLiteral(self, node: OffsetLiteral, **kwargs): + return node.value if isinstance(node.value, str) else f"{node.value}_c" + + StringLiteral = as_fmt("{value}") + FunCall = as_fmt("{fun}({','.join(args)})") + Lambda = as_mako( + "[=](${','.join('auto ' + p for p in params)}){return ${expr};}" + ) # TODO capture + StencilClosure = as_mako( + "closure(${domain}, ${stencil}, out(${','.join(outputs)}), ${','.join(inputs)})" + ) + FencilDefinition = as_mako( + """ + auto ${id} = [](${','.join('auto&& ' + p for p in params)}){ + fencil(${'\\n'.join(closures)}); + }; + """ + ) + FunctionDefinition = as_mako( + """ + inline constexpr auto ${id} = [](${','.join('auto ' + p for p in params)}){ + return ${expr}; + }; + """ + ) + Program = as_fmt("{''.join(function_definitions)} {''.join(fencil_definitions)}") + + @classmethod + def apply(cls, root, **kwargs: Any) -> str: + transformed = apply_common_transforms( 
+ root, + use_tmps=kwargs.get("use_tmps", False), + offset_provider=kwargs.get("offset_provider", None), + ) + generated_code = super().apply(transformed, **kwargs) + formatted_code = codegen.format_source("cpp", generated_code, style="LLVM") + return formatted_code + + +backend.register_backend( + "cpptoy", lambda prog, *args, **kwargs: print(ToyCpp.apply(prog, **kwargs)) +) diff --git a/src/iterator/backends/double_roundtrip.py b/src/iterator/backends/double_roundtrip.py new file mode 100644 index 0000000000..ea1d57c263 --- /dev/null +++ b/src/iterator/backends/double_roundtrip.py @@ -0,0 +1,9 @@ +from eve.concepts import Node +from iterator.backends import backend, embedded + + +def executor(ir: Node, *args, **kwargs): + embedded.executor(ir, *args, dispatch_backend=embedded._BACKEND_NAME, **kwargs) + + +backend.register_backend("double_roundtrip", executor) diff --git a/src/iterator/backends/embedded.py b/src/iterator/backends/embedded.py new file mode 100644 index 0000000000..718e5aa9af --- /dev/null +++ b/src/iterator/backends/embedded.py @@ -0,0 +1,154 @@ +import importlib.util +import tempfile + +import iterator +from eve import codegen +from eve.codegen import FormatTemplate as as_fmt +from eve.codegen import MakoTemplate as as_mako +from eve.concepts import Node +from iterator.backends import backend +from iterator.ir import AxisLiteral, FencilDefinition, OffsetLiteral +from iterator.transforms import apply_common_transforms + + +class EmbeddedDSL(codegen.TemplatedGenerator): + Sym = as_fmt("{id}") + SymRef = as_fmt("{id}") + BoolLiteral = as_fmt("{value}") + IntLiteral = as_fmt("{value}") + FloatLiteral = as_fmt("{value}") + NoneLiteral = as_fmt("None") + OffsetLiteral = as_fmt("{value}") + AxisLiteral = as_fmt("{value}") + StringLiteral = as_fmt("{value}") + FunCall = as_fmt("{fun}({','.join(args)})") + Lambda = as_mako("(lambda ${','.join(params)}: ${expr})") + StencilClosure = as_mako( + "closure(${domain}, ${stencil}, [${','.join(outputs)}], 
[${','.join(inputs)}])" + ) + FencilDefinition = as_mako( + """ +@fendef +def ${id}(${','.join(params)}): + ${'\\n '.join(closures)} + """ + ) + FunctionDefinition = as_mako( + """ +@fundef +def ${id}(${','.join(params)}): + return ${expr} + """ + ) + Program = as_fmt( + """ +{''.join(function_definitions)} {''.join(fencil_definitions)}""" + ) + + +# TODO this wrapper should be replaced by an extension of the IR +class WrapperGenerator(EmbeddedDSL): + def visit_FencilDefinition(self, node: FencilDefinition, *, tmps): + params = self.visit(node.params) + non_tmp_params = [param for param in params if param not in tmps] + + body = [] + for tmp, domain in tmps.items(): + axis_literals = [named_range.args[0].value for named_range in domain.args] + origin = ( + "{" + + ", ".join( + f"{named_range.args[0].value}: -{self.visit(named_range.args[1])}" + for named_range in domain.args + ) + + "}" + ) + shape = ( + "(" + + ", ".join( + f"{self.visit(named_range.args[2])}-{self.visit(named_range.args[1])}" + for named_range in domain.args + ) + + ")" + ) + body.append( + f"{tmp} = np_as_located_field({','.join(axis_literals)}, origin={origin})(np.full({shape}, np.nan))" + ) + + body.append(f"{node.id}({','.join(params)}, **kwargs)") + body = "\n ".join(body) + return f"\ndef {node.id}_wrapper({','.join(non_tmp_params)}, **kwargs):\n {body}\n" + + +_BACKEND_NAME = "embedded" + + +def executor(ir: Node, *args, **kwargs): + debug = "debug" in kwargs and kwargs["debug"] is True + use_tmps = "use_tmps" in kwargs and kwargs["use_tmps"] is True + + tmps = dict() + + def register_tmp(tmp, domain): + tmps[tmp] = domain + + ir = apply_common_transforms( + ir, use_tmps=use_tmps, offset_provider=kwargs["offset_provider"], register_tmp=register_tmp + ) + + program = EmbeddedDSL.apply(ir) + wrapper = WrapperGenerator.apply(ir, tmps=tmps) + offset_literals = ( + ir.iter_tree().if_isinstance(OffsetLiteral).getattr("value").if_isinstance(str).to_set() + ) + axis_literals = 
ir.iter_tree().if_isinstance(AxisLiteral).getattr("value").to_set() + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".py", + delete=not debug, + ) as tmp: + if debug: + print(tmp.name) + header = """ +import numpy as np +from iterator.builtins import * +from iterator.runtime import * +from iterator.embedded import np_as_located_field +""" + offset_literals = [f'{o} = offset("{o}")' for o in offset_literals] + axis_literals = [f'{o} = CartesianAxis("{o}")' for o in axis_literals] + tmp.write(header) + tmp.write("\n".join(offset_literals)) + tmp.write("\n") + tmp.write("\n".join(axis_literals)) + tmp.write("\n") + tmp.write(program) + tmp.write(wrapper) + tmp.flush() + + spec = importlib.util.spec_from_file_location("module.name", tmp.name) + foo = importlib.util.module_from_spec(spec) + spec.loader.exec_module(foo) # type: ignore + + fencil_name = ir.fencil_definitions[0].id + fencil = getattr(foo, fencil_name + "_wrapper") + assert "offset_provider" in kwargs + + new_kwargs = {} + new_kwargs["offset_provider"] = kwargs["offset_provider"] + if "column_axis" in kwargs: + new_kwargs["column_axis"] = kwargs["column_axis"] + + if "dispatch_backend" not in kwargs: + iterator.builtins.builtin_dispatch.push_key("embedded") + fencil(*args, **new_kwargs) + iterator.builtins.builtin_dispatch.pop_key() + else: + fencil( + *args, + **new_kwargs, + backend=kwargs["dispatch_backend"], + ) + + +backend.register_backend(_BACKEND_NAME, executor) diff --git a/src/iterator/backends/lisp.py b/src/iterator/backends/lisp.py new file mode 100644 index 0000000000..acf32b08e6 --- /dev/null +++ b/src/iterator/backends/lisp.py @@ -0,0 +1,67 @@ +from typing import Any + +from eve.codegen import FormatTemplate as as_fmt +from eve.codegen import TemplatedGenerator +from iterator.backends import backend +from iterator.transforms import apply_common_transforms + + +class ToLispLike(TemplatedGenerator): + Sym = as_fmt("{id}") + FunCall = as_fmt("({fun} {' '.join(args)})") + IntLiteral = 
as_fmt("{value}") + OffsetLiteral = as_fmt("{value}") + StringLiteral = as_fmt("{value}") + SymRef = as_fmt("{id}") + Program = as_fmt( + """ + {''.join(function_definitions)} + {''.join(fencil_definitions)} + {''.join(setqs)} + """ + ) + StencilClosure = as_fmt( + """( + :domain {domain} + :stencil {stencil} + :outputs {' '.join(outputs)} + :inputs {' '.join(inputs)} + ) + """ + ) + FencilDefinition = as_fmt( + """(defen {id}({' '.join(params)}) + {''.join(closures)}) + """ + ) + FunctionDefinition = as_fmt( + """(defun {id}({' '.join(params)}) + {expr} + ) + +""" + ) + Lambda = as_fmt( + """(lambda ({' '.join(params)}) + {expr} + )""" + ) + + @classmethod + def apply(cls, root, **kwargs: Any) -> str: + transformed = apply_common_transforms( + root, use_tmps=kwargs.get("use_tmps", False), offset_provider=kwargs["offset_provider"] + ) + generated_code = super().apply(transformed, **kwargs) + try: + from yasi import indent_code + + indented = indent_code(generated_code, "--dialect lisp") + return "".join(indented["indented_code"]) + except ImportError: + return generated_code + + +backend.register_backend( + "lisp", lambda prog, *args, **kwargs: print(ToLispLike.apply(prog, **kwargs)) +) diff --git a/src/iterator/builtins.py b/src/iterator/builtins.py new file mode 100644 index 0000000000..c762178c75 --- /dev/null +++ b/src/iterator/builtins.py @@ -0,0 +1,130 @@ +from iterator.dispatcher import Dispatcher + + +__all__ = [ + "deref", + "shift", + "lift", + "reduce", + "scan", + "is_none", + "domain", + "named_range", + "compose", + "if_", + "or_", + "minus", + "plus", + "mul", + "div", + "eq", + "greater", + "make_tuple", + "nth", + "plus", + "reduce", + "scan", + "shift", +] + +builtin_dispatch = Dispatcher() + + +class BackendNotSelectedError(RuntimeError): + def __init__(self) -> None: + super().__init__("Backend not selected") + + +@builtin_dispatch +def deref(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def shift(*args): + raise 
BackendNotSelectedError() + + +@builtin_dispatch +def lift(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def reduce(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def scan(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def is_none(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def domain(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def named_range(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def compose(sten): + raise BackendNotSelectedError() + + +@builtin_dispatch +def if_(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def or_(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def minus(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def plus(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def mul(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def div(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def eq(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def greater(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def make_tuple(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def nth(*args): + raise BackendNotSelectedError() diff --git a/src/iterator/dispatcher.py b/src/iterator/dispatcher.py new file mode 100644 index 0000000000..9f14f55b87 --- /dev/null +++ b/src/iterator/dispatcher.py @@ -0,0 +1,55 @@ +from typing import Any, Callable, Dict, List + + +# TODO test + + +class _fun_dispatcher: + def __init__(self, dispatcher, fun) -> None: + self.dispatcher = dispatcher + self.fun = fun + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if self.dispatcher.key is None: + return self.fun(*args, **kwargs) + else: + return self.dispatcher._funs[self.dispatcher.key][self.fun.__name__](*args, **kwargs) + + def register(self, key): + self.dispatcher.register_key(key) + + def _impl(fun): + 
self.dispatcher._funs[key][self.fun.__name__] = fun + + return _impl + + +class Dispatcher: + def __init__(self) -> None: + self._funs: Dict[str, Dict[str, Callable]] = {} + self.key_stack: List[str] = [] + + @property + def key(self): + return self.key_stack[-1] if self.key_stack else None + + def register_key(self, key): + if key not in self._funs: + self._funs[key] = {} + + def push_key(self, key): + if key not in self._funs: + raise RuntimeError(f"Key {key} not registered") + self.key_stack.append(key) + + def pop_key(self): + self.key_stack.pop() + + def clear_key(self): + self.key_stack = [] + + def __call__(self, fun): + return self.dispatch(fun) + + def dispatch(self, fun): + return _fun_dispatcher(self, fun) diff --git a/src/iterator/embedded.py b/src/iterator/embedded.py new file mode 100644 index 0000000000..bfe0388e7e --- /dev/null +++ b/src/iterator/embedded.py @@ -0,0 +1,591 @@ +import itertools +import numbers +from dataclasses import dataclass + +import numpy as np + +import iterator +from iterator import builtins +from iterator.runtime import CartesianAxis, Offset +from iterator.utils import tupelize + + +EMBEDDED = "embedded" + + +class NeighborTableOffsetProvider: + def __init__(self, tbl, origin_axis, neighbor_axis, max_neighbors) -> None: + self.tbl = tbl + self.origin_axis = origin_axis + self.neighbor_axis = neighbor_axis + self.max_neighbors = max_neighbors + + +@builtins.deref.register(EMBEDDED) +def deref(it): + return it.deref() + + +@builtins.if_.register(EMBEDDED) +def if_(cond, t, f): + return t if cond else f + + +@builtins.or_.register(EMBEDDED) +def or_(a, b): + return a or b + + +@builtins.nth.register(EMBEDDED) +def nth(i, tup): + return tup[i] + + +@builtins.make_tuple.register(EMBEDDED) +def make_tuple(*args): + return (*args,) + + +@builtins.lift.register(EMBEDDED) +def lift(stencil): + def impl(*args): + class wrap_iterator: + def __init__(self, *, offsets=None, elem=None) -> None: + self.offsets = offsets or [] + self.elem = 
elem + + # TODO needs to be supported by all iterators that represent tuples + def __getitem__(self, index): + return wrap_iterator(offsets=self.offsets, elem=index) + + def shift(self, *offsets): + return wrap_iterator(offsets=[*self.offsets, *offsets], elem=self.elem) + + def max_neighbors(self): + # TODO cleanup, test edge cases + open_offsets = get_open_offsets(*self.offsets) + assert open_offsets + assert isinstance( + args[0].offset_provider[open_offsets[0].value], + NeighborTableOffsetProvider, + ) + return args[0].offset_provider[open_offsets[0].value].max_neighbors + + def deref(self): + shifted_args = tuple(map(lambda arg: arg.shift(*self.offsets), args)) + + if any(shifted_arg.is_none() for shifted_arg in shifted_args): + return None + + if self.elem is None: + return stencil(*shifted_args) + else: + return stencil(*shifted_args)[self.elem] + + return wrap_iterator() + + return impl + + +@builtins.reduce.register(EMBEDDED) +def reduce(fun, init): + def sten(*iters): + # TODO: assert check_that_all_iterators_are_compatible(*iters) + first_it = iters[0] + n = first_it.max_neighbors() + res = init + for i in range(n): + # we can check a single argument + # because all arguments share the same pattern + if builtins.deref(builtins.shift(i)(first_it)) is None: + break + res = fun( + res, + *(builtins.deref(builtins.shift(i)(it)) for it in iters), + ) + return res + + return sten + + +class _None: + """Dummy object to allow execution of expression containing Nones in non-active path. + + E.g. 
+ `if_(is_none(state), 42, 42+state)` + here 42+state needs to be evaluatable even if is_none(state) + + TODO: all possible arithmetic operations + """ + + def __add__(self, other): + return _None() + + def __radd__(self, other): + return _None() + + def __sub__(self, other): + return _None() + + def __rsub__(self, other): + return _None() + + def __mul__(self, other): + return _None() + + def __rmul__(self, other): + return _None() + + def __truediv__(self, other): + return _None() + + def __rtruediv__(self, other): + return _None() + + def __getitem__(self, i): + return _None() + + +@builtins.is_none.register(EMBEDDED) +def is_none(arg): + return isinstance(arg, _None) + + +@builtins.domain.register(EMBEDDED) +def domain(*args): + domain = {} + for arg in args: + domain.update(arg) + return domain + + +@builtins.named_range.register(EMBEDDED) +def named_range(tag, start, end): + return {tag: range(start, end)} + + +@builtins.minus.register(EMBEDDED) +def minus(first, second): + return first - second + + +@builtins.plus.register(EMBEDDED) +def plus(first, second): + return first + second + + +@builtins.mul.register(EMBEDDED) +def mul(first, second): + return first * second + + +@builtins.div.register(EMBEDDED) +def div(first, second): + return first / second + + +@builtins.eq.register(EMBEDDED) +def eq(first, second): + return first == second + + +@builtins.greater.register(EMBEDDED) +def greater(first, second): + return first > second + + +def named_range_(axis, range_): + return ((axis, i) for i in range_) + + +def domain_iterator(domain): + return ( + dict(elem) + for elem in itertools.product( + *map(lambda tup: named_range_(tup[0], tup[1]), domain.items()) + ) + ) + + +def execute_shift(pos, tag, index, *, offset_provider): + if tag in pos and pos[tag] is None: # sparse field with offset as neighbor dimension + new_pos = pos.copy() + new_pos[tag] = index + return new_pos + assert tag.value in offset_provider + offset_implementation = 
offset_provider[tag.value] + if isinstance(offset_implementation, CartesianAxis): + assert offset_implementation in pos + new_pos = pos.copy() + new_pos[offset_implementation] += index + return new_pos + elif isinstance(offset_implementation, NeighborTableOffsetProvider): + assert offset_implementation.origin_axis in pos + new_pos = pos.copy() + del new_pos[offset_implementation.origin_axis] + if offset_implementation.tbl[pos[offset_implementation.origin_axis], index] is None: + return None + else: + new_pos[offset_implementation.neighbor_axis] = offset_implementation.tbl[ + pos[offset_implementation.origin_axis], index + ] + return new_pos + + raise AssertionError() + + +# The following holds for shifts: +# shift(tag, index)(inp) -> full shift +# shift(tag)(inp) -> incomplete shift +# shift(index)(shift(tag)(inp)) -> full shift +# Therefore the following transformation holds +# shift(e2v,0)(shift(c2e,2)(cell_field)) +# = shift(0)(shift(e2v)(shift(2)(shift(c2e)(cell_field)))) +# = shift(c2e, 2, e2v, 0)(cell_field) +# = shift(c2e,e2v,2,0)(cell_field) <-- v2c,e2c twice incomplete shift +# = shift(2,0)(shift(c2e,e2v)(cell_field)) +# for implementations it means everytime we have an index, we can "execute" a concrete shift +def group_offsets(*offsets): + tag_stack = [] + complete_offsets = [] + for offset in offsets: + if not isinstance(offset, int): + tag_stack.append(offset) + else: + assert tag_stack + tag = tag_stack.pop(0) + complete_offsets.append((tag, offset)) + return complete_offsets, tag_stack + + +def shift_position(pos, *complete_offsets, offset_provider): + new_pos = pos.copy() + for tag, index in complete_offsets: + new_pos = execute_shift(new_pos, tag, index, offset_provider=offset_provider) + if new_pos is None: + return None + return new_pos + + +def get_open_offsets(*offsets): + return group_offsets(*offsets)[1] + + +class Undefined: + def __float__(self): + return np.nan + + @classmethod + def _setup_math_operations(cls): + ops = [ + "__add__", + 
"__sub__", + "__mul__", + "__matmul__", + "__truediv__", + "__floordiv__", + "__mod__", + "__divmod__", + "__pow__", + "__lshift__", + "__rshift__", + "__and__", + "__xor__", + "__or__", + "__radd__", + "__rsub__", + "__rmul__", + "__rmatmul__", + "__rtruediv__", + "__rfloordiv__", + "__rmod__", + "__rdivmod__", + "__rpow__", + "__rlshift__", + "__rrshift__", + "__rand__", + "__rxor__", + "__ror__", + "__neg__", + "__pos__", + "__abs__", + "__invert__", + ] + for op in ops: + setattr(cls, op, lambda self, *args, **kwargs: _UNDEFINED) + + +Undefined._setup_math_operations() + +_UNDEFINED = Undefined() + + +class MDIterator: + def __init__( + self, field, pos, *, incomplete_offsets=None, offset_provider, column_axis=None + ) -> None: + self.field = field + self.pos = pos + self.incomplete_offsets = incomplete_offsets or [] + self.offset_provider = offset_provider + self.column_axis = column_axis + + def shift(self, *offsets): + complete_offsets, open_offsets = group_offsets(*self.incomplete_offsets, *offsets) + return MDIterator( + self.field, + shift_position(self.pos, *complete_offsets, offset_provider=self.offset_provider), + incomplete_offsets=open_offsets, + offset_provider=self.offset_provider, + column_axis=self.column_axis, + ) + + def max_neighbors(self): + assert self.incomplete_offsets + assert isinstance( + self.offset_provider[self.incomplete_offsets[0].value], NeighborTableOffsetProvider + ) + return self.offset_provider[self.incomplete_offsets[0].value].max_neighbors + + def is_none(self): + return self.pos is None + + def deref(self): + shifted_pos = self.pos.copy() + + if not all(axis in shifted_pos.keys() for axis in self.field.axises): + raise IndexError("Iterator position doesn't point to valid location for its field.") + slice_column = {} + if self.column_axis is not None: + slice_column[self.column_axis] = slice(shifted_pos[self.column_axis], None) + del shifted_pos[self.column_axis] + ordered_indices = get_ordered_indices( + self.field.axises, 
+ shifted_pos, + slice_axises=slice_column, + ) + try: + return self.field[ordered_indices] + except IndexError: + return _UNDEFINED + + +def make_in_iterator(inp, pos, offset_provider, *, column_axis): + sparse_dimensions = [axis for axis in inp.axises if isinstance(axis, Offset)] + assert len(sparse_dimensions) <= 1 # TODO multiple is not a current use case + new_pos = pos.copy() + for axis in sparse_dimensions: + new_pos[axis] = None + if column_axis is not None: + # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted + new_pos[column_axis] = 0 + return MDIterator( + inp, + new_pos, + incomplete_offsets=[*sparse_dimensions], + offset_provider=offset_provider, + column_axis=column_axis, + ) + + +builtins.builtin_dispatch.push_key(EMBEDDED) # makes embedded the default + + +class LocatedField: + """A Field with named dimensions/axises. + + Axis keys can be any objects that are hashable. + """ + + def __init__(self, getter, axises, *, setter=None, array=None): + self.getter = getter + self.axises = axises + self.setter = setter + self.array = array + + def __getitem__(self, indices): + indices = tupelize(indices) + return self.getter(indices) + + def __setitem__(self, indices, value): + if self.setter is None: + raise TypeError("__setitem__ not supported for this field") + self.setter(indices, value) + + def __array__(self): + if self.array is None: + raise TypeError("__array__ not supported for this field") + return self.array() + + @property + def shape(self): + if self.array is None: + raise TypeError("`shape` not supported for this field") + return self.array().shape + + +def get_ordered_indices(axises, pos, *, slice_axises=None): + """pos is a dictionary from axis to offset.""" # noqa: D403 + slice_axises = slice_axises or dict() + assert all(axis in [*pos.keys(), *slice_axises] for axis in axises) + return tuple(pos[axis] if axis in pos else slice_axises[axis] for axis in axises) + + +def 
_tupsum(a, b):
+    def combine_slice(s, t):
+        is_slice = False
+        if isinstance(s, slice):
+            is_slice = True
+            first = s.start
+            assert s.step is None
+            assert s.stop is None
+        else:
+            assert isinstance(s, numbers.Integral)
+            first = s
+        if isinstance(t, slice):
+            is_slice = True
+            second = t.start
+            assert t.step is None
+            assert t.stop is None
+        else:
+            assert isinstance(t, numbers.Integral)
+            second = t
+        start = first + second
+        return slice(start, None) if is_slice else start
+
+    return tuple(combine_slice(*i) for i in zip(a, b))
+
+
+def np_as_located_field(*axises, origin=None):
+    def _maker(a: np.ndarray):
+        if a.ndim != len(axises):
+            raise TypeError("ndarray.ndim incompatible with number of given axises")
+
+        if origin is not None:
+            offsets = get_ordered_indices(axises, origin)
+        else:
+            offsets = tuple(0 for _ in axises)
+
+        def setter(indices, value):
+            a[_tupsum(indices, offsets)] = value
+
+        def getter(indices):
+            return a[_tupsum(indices, offsets)]
+
+        return LocatedField(getter, axises, setter=setter, array=a.__array__)
+
+    return _maker
+
+
+def index_field(axis):
+    return LocatedField(lambda index: index[0], (axis,))
+
+
+@builtins.shift.register(EMBEDDED)
+def shift(*offsets):
+    def impl(it):
+        return it.shift(*offsets)
+
+    return impl
+
+
+@dataclass
+class Column:
+    axis: CartesianAxis
+    range: range  # noqa: A003
+
+
+class ScanArgIterator:
+    def __init__(self, wrapped_iter, k_pos, *, offsets=None) -> None:
+        self.wrapped_iter = wrapped_iter
+        self.offsets = offsets or []
+        self.k_pos = k_pos
+
+    def deref(self):
+        return self.wrapped_iter.deref()[self.k_pos]
+
+    def shift(self, *offsets):
+        # NOTE(review): `k_pos` is a required positional argument of __init__;
+        # the original call omitted it, which would raise TypeError on any shift.
+        return ScanArgIterator(self.wrapped_iter, self.k_pos, offsets=[*offsets, *self.offsets])
+
+
+def shifted_scan_arg(k_pos):
+    def impl(it):
+        return ScanArgIterator(it, k_pos=k_pos)
+
+    return impl
+
+
+def fendef_embedded(fun, *args, **kwargs):  # noqa: 536
+    assert "offset_provider" in kwargs
+
+    @iterator.runtime.closure.register(EMBEDDED)
+    def closure(domain, sten,
outs, ins):  # domain is Dict[axis, range]
+
+        column = None
+        if "column_axis" in kwargs:
+            _column_axis = kwargs["column_axis"]
+            column = Column(_column_axis, domain[_column_axis])
+            del domain[_column_axis]
+
+        @builtins.scan.register(
+            EMBEDDED
+        )  # TODO this is a bit ugly, alternative: pass scan range via iterator
+        def scan(scan_pass, is_forward, init):
+            def impl(*iters):
+                if column is None:
+                    raise RuntimeError("Column axis is not defined, cannot scan.")
+
+                _range = column.range
+                if not is_forward:
+                    _range = reversed(_range)
+
+                state = init
+                if state is None:
+                    state = _None()
+                cols = []
+                for i in _range:
+                    state = scan_pass(
+                        state, *map(shifted_scan_arg(i), iters)
+                    )  # more generic scan returns state and result as 2 different things
+                    cols.append([*tupelize(state)])
+
+                cols = tuple(map(np.asarray, (map(list, zip(*cols)))))
+                # transpose to get tuple of columns as np array
+
+                if not is_forward:
+                    cols = tuple(map(np.flip, cols))
+                return cols
+
+            return impl
+
+        for pos in domain_iterator(domain):
+            ins_iters = list(
+                make_in_iterator(
+                    inp,
+                    pos,
+                    kwargs["offset_provider"],
+                    column_axis=column.axis if column is not None else None,
+                )
+                for inp in ins
+            )
+            res = sten(*ins_iters)
+            if not isinstance(res, tuple):
+                res = (res,)
+            if not len(res) == len(outs):
+                # NOTE(review): the original created this exception without `raise`,
+                # silently discarding the check.
+                raise IndexError("Number of return values doesn't match number of output fields.")
+
+            for r, out in zip(res, outs):
+                if column is None:
+                    ordered_indices = get_ordered_indices(out.axises, pos)
+                    out[ordered_indices] = r
+                else:
+                    colpos = pos.copy()
+                    for k in column.range:
+                        colpos[column.axis] = k
+                        ordered_indices = get_ordered_indices(out.axises, colpos)
+                        out[ordered_indices] = r[k]
+
+    fun(*args)
+
+
+iterator.runtime.fendef_registry[None] = fendef_embedded
diff --git a/src/iterator/ir.py b/src/iterator/ir.py
new file mode 100644
index 0000000000..58ca48d52f
--- /dev/null
+++ b/src/iterator/ir.py
@@ -0,0 +1,119 @@
+from typing import List, Union
+
+from eve import Node
+from
eve.traits import SymbolName, SymbolTableTrait +from eve.type_definitions import SymbolRef +from iterator.util.sym_validation import validate_symbol_refs + + +class Sym(Node): # helper + id: SymbolName # noqa: A003 + + +class Expr(Node): + ... + + +class BoolLiteral(Expr): + value: bool + + +class IntLiteral(Expr): + value: int + + +class FloatLiteral(Expr): + value: float # TODO other float types + + +class StringLiteral(Expr): + value: str + + +class NoneLiteral(Expr): + _none_literal: int = 0 + + +class OffsetLiteral(Expr): + value: Union[int, str] + + +class AxisLiteral(Expr): + value: str + + +class SymRef(Expr): + id: SymbolRef # noqa: A003 + + +class Lambda(Expr, SymbolTableTrait): + params: List[Sym] + expr: Expr + + +class FunCall(Expr): + fun: Expr # VType[Callable] + args: List[Expr] + + +class FunctionDefinition(Node, SymbolTableTrait): + id: SymbolName # noqa: A003 + params: List[Sym] + expr: Expr + + def __eq__(self, other): + return isinstance(other, FunctionDefinition) and self.id == other.id + + def __hash__(self): + return hash(self.id) + + +class Setq(Node): + id: SymbolName # noqa: A003 + expr: Expr + + +class StencilClosure(Node): + domain: Expr + stencil: Expr + outputs: List[SymRef] + inputs: List[SymRef] + + +class FencilDefinition(Node, SymbolTableTrait): + id: SymbolName # noqa: A003 + params: List[Sym] + closures: List[StencilClosure] + + +class Program(Node, SymbolTableTrait): + function_definitions: List[FunctionDefinition] + fencil_definitions: List[FencilDefinition] + setqs: List[Setq] + + builtin_functions = list( + Sym(id=name) + for name in [ + "domain", + "named_range", + "compose", + "lift", + "is_none", + "make_tuple", + "nth", + "reduce", + "deref", + "shift", + "scan", + "plus", + "minus", + "mul", + "div", + "eq", + "greater", + "less", + "if_", + "or_", + ] + ) + _validate_symbol_refs = validate_symbol_refs() diff --git a/src/iterator/library.py b/src/iterator/library.py new file mode 100644 index 0000000000..18c87aeb49 --- 
/dev/null +++ b/src/iterator/library.py @@ -0,0 +1,12 @@ +from iterator.builtins import reduce + + +def sum_(fun=None): + if fun is None: + return reduce(lambda a, b: a + b, 0) + else: + return reduce(lambda first, a, b: first + fun(a, b), 0) # TODO tracing for *args + + +def dot(a, b): + return reduce(lambda acc, a_n, c_n: acc + a_n * c_n, 0)(a, b) diff --git a/src/iterator/runtime.py b/src/iterator/runtime.py new file mode 100644 index 0000000000..deca4a6e33 --- /dev/null +++ b/src/iterator/runtime.py @@ -0,0 +1,85 @@ +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Union + +from iterator.builtins import BackendNotSelectedError, builtin_dispatch + + +__all__ = ["offset", "fundef", "fendef", "closure", "CartesianAxis"] + + +@dataclass +class Offset: + value: Optional[Union[int, str]] = None + + def __hash__(self) -> int: + return hash(self.value) + + +def offset(value): + return Offset(value) + + +@dataclass +class CartesianAxis: + value: str + + def __hash__(self) -> int: + return hash(self.value) + + +fendef_registry: Dict[Optional[Callable], Callable] = {} + + +# TODO the dispatching is linear, not sure if there is an easy way to make it constant +def fendef(*dec_args, **dec_kwargs): + def wrapper(fun): + def impl(*args, **kwargs): + kwargs = {**kwargs, **dec_kwargs} + + for key, val in fendef_registry.items(): + if key is not None and key(kwargs): + val(fun, *args, **kwargs) + return + if None in fendef_registry: + fendef_registry[None](fun, *args, **kwargs) + return + raise RuntimeError("Unreachable") + + return impl + + if len(dec_args) == 1 and len(dec_kwargs) == 0 and callable(dec_args[0]): + return wrapper(dec_args[0]) + else: + assert len(dec_args) == 0 + return wrapper + + +class FundefDispatcher: + _hook = None + # hook is an object that + # - evaluates to true if it should be used, + # - is callable with an instance of FundefDispatcher + # - returns callable that takes the function arguments + + def __init__(self, fun) 
-> None: + self.fun = fun + self.__name__ = fun.__name__ + + def __call__(self, *args): + if type(self)._hook: + return type(self)._hook(self)(*args) + else: + return self.fun(*args) + + @classmethod + def register_hook(cls, hook): + cls._hook = hook + + +def fundef(fun): + return FundefDispatcher(fun) + + +@builtin_dispatch +def closure(*args): + return BackendNotSelectedError() diff --git a/src/iterator/tracing.py b/src/iterator/tracing.py new file mode 100644 index 0000000000..49731d23ef --- /dev/null +++ b/src/iterator/tracing.py @@ -0,0 +1,317 @@ +import inspect +from typing import List + +import iterator +from eve import Node +from iterator.backend_executor import execute_program +from iterator.ir import ( + AxisLiteral, + BoolLiteral, + Expr, + FencilDefinition, + FloatLiteral, + FunCall, + FunctionDefinition, + IntLiteral, + Lambda, + NoneLiteral, + OffsetLiteral, + Program, + StencilClosure, + Sym, + SymRef, +) +from iterator.runtime import CartesianAxis + + +TRACING = "tracing" + + +def monkeypatch_method(cls): + def decorator(func): + setattr(cls, func.__name__, func) + return func + + return decorator + + +def _patch_Expr(): + @monkeypatch_method(Expr) + def __add__(self, other): + return FunCall(fun=SymRef(id="plus"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __radd__(self, other): + return make_node(other) + self + + @monkeypatch_method(Expr) + def __mul__(self, other): + return FunCall(fun=SymRef(id="mul"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __rmul__(self, other): + return make_node(other) * self + + @monkeypatch_method(Expr) + def __truediv__(self, other): + return FunCall(fun=SymRef(id="div"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __sub__(self, other): + return FunCall(fun=SymRef(id="minus"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __gt__(self, other): + return FunCall(fun=SymRef(id="greater"), args=[self, make_node(other)]) + + 
@monkeypatch_method(Expr) + def __lt__(self, other): + return FunCall(fun=SymRef(id="less"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __call__(self, *args): + return FunCall(fun=self, args=[*make_node(args)]) + + +_patch_Expr() + + +class PatchedFunctionDefinition(FunctionDefinition): + def __call__(self, *args): + return FunCall(fun=SymRef(id=str(self.id)), args=[*make_node(args)]) + + +def _s(id_): + return SymRef(id=id_) + + +def trace_function_argument(arg): + if isinstance(arg, iterator.runtime.FundefDispatcher): + make_function_definition(arg.fun) + return _s(arg.fun.__name__) + return arg + + +def _f(fun, *args): + if isinstance(fun, str): + fun = _s(fun) + + args = [trace_function_argument(arg) for arg in args] + return FunCall(fun=fun, args=[*make_node(args)]) + + +# builtins +@iterator.builtins.deref.register(TRACING) +def deref(arg): + return _f("deref", arg) + + +@iterator.builtins.lift.register(TRACING) +def lift(sten): + return _f("lift", sten) + + +@iterator.builtins.reduce.register(TRACING) +def reduce(*args): + return _f("reduce", *args) + + +@iterator.builtins.scan.register(TRACING) +def scan(*args): + return _f("scan", *args) + + +@iterator.builtins.is_none.register(TRACING) +def is_none(*args): + return _f("is_none", *args) + + +@iterator.builtins.make_tuple.register(TRACING) +def make_tuple(*args): + return _f("make_tuple", *args) + + +@iterator.builtins.nth.register(TRACING) +def nth(*args): + return _f("nth", *args) + + +@iterator.builtins.compose.register(TRACING) +def compose(*args): + return _f("compose", *args) + + +@iterator.builtins.domain.register(TRACING) +def domain(*args): + return _f("domain", *args) + + +@iterator.builtins.named_range.register(TRACING) +def named_range(*args): + return _f("named_range", *args) + + +@iterator.builtins.if_.register(TRACING) +def if_(*args): + return _f("if_", *args) + + +@iterator.builtins.or_.register(TRACING) +def or_(*args): + return _f("or_", *args) + + +# shift 
promotes its arguments to literals, therefore special
+@iterator.builtins.shift.register(TRACING)
+def shift(*offsets):
+    offsets = tuple(OffsetLiteral(value=o) if isinstance(o, (str, int)) else o for o in offsets)
+    return _f("shift", *offsets)
+
+
+@iterator.builtins.plus.register(TRACING)
+def plus(*args):
+    return _f("plus", *args)
+
+
+@iterator.builtins.minus.register(TRACING)
+def minus(*args):
+    return _f("minus", *args)
+
+
+@iterator.builtins.mul.register(TRACING)
+def mul(*args):
+    return _f("mul", *args)
+
+
+@iterator.builtins.div.register(TRACING)
+def div(*args):
+    return _f("div", *args)
+
+
+@iterator.builtins.eq.register(TRACING)
+def eq(*args):
+    return _f("eq", *args)
+
+
+@iterator.builtins.greater.register(TRACING)
+def greater(*args):
+    return _f("greater", *args)
+
+
+# helpers
+def make_node(o):
+    if isinstance(o, Node):
+        return o
+    if callable(o):
+        # NOTE(review): the literal "<lambda>" (a lambda's __name__) was stripped
+        # to "" by angle-bracket-eating extraction; restored here.
+        if o.__name__ == "<lambda>":
+            return lambdadef(o)
+        if hasattr(o, "__code__") and o.__code__.co_flags & inspect.CO_NESTED:
+            return lambdadef(o)
+    if isinstance(o, iterator.runtime.Offset):
+        return OffsetLiteral(value=o.value)
+    if isinstance(o, bool):
+        return BoolLiteral(value=o)
+    if isinstance(o, int):
+        return IntLiteral(value=o)
+    if isinstance(o, float):
+        return FloatLiteral(value=o)
+    if isinstance(o, CartesianAxis):
+        return AxisLiteral(value=o.value)
+    if isinstance(o, tuple):
+        return tuple(make_node(arg) for arg in o)
+    if isinstance(o, list):
+        return list(make_node(arg) for arg in o)
+    if o is None:
+        return NoneLiteral()
+    if isinstance(o, iterator.runtime.FundefDispatcher):
+        return SymRef(id=o.fun.__name__)
+    raise NotImplementedError(f"Cannot handle {o}")
+
+
+def trace_function_call(fun):
+    body = fun(*list(_s(param) for param in inspect.signature(fun).parameters.keys()))
+    return make_node(body) if body is not None else None
+
+
+def lambdadef(fun):
+    return Lambda(
+        params=list(Sym(id=param) for param in inspect.signature(fun).parameters.keys()),
+
expr=trace_function_call(fun), + ) + + +def make_function_definition(fun): + res = PatchedFunctionDefinition( + id=fun.__name__, + params=list(Sym(id=param) for param in inspect.signature(fun).parameters.keys()), + expr=trace_function_call(fun), + ) + Tracer.add_fundef(res) + return res + + +class FundefTracer: + def __call__(self, fundef_dispatcher: iterator.runtime.FundefDispatcher): + def fun(*args): + res = make_function_definition(fundef_dispatcher.fun) + return res(*args) + + return fun + + def __bool__(self): + return iterator.builtins.builtin_dispatch.key == TRACING + + +iterator.runtime.FundefDispatcher.register_hook(FundefTracer()) + + +class Tracer: + fundefs: List[FunctionDefinition] = [] + closures: List[StencilClosure] = [] + + @classmethod + def add_fundef(cls, fun): + if fun not in cls.fundefs: + cls.fundefs.append(fun) + + @classmethod + def add_closure(cls, closure): + cls.closures.append(closure) + + def __enter__(self): + iterator.builtins.builtin_dispatch.push_key(TRACING) + + def __exit__(self, exc_type, exc_value, exc_traceback): + type(self).fundefs = [] + type(self).closures = [] + iterator.builtins.builtin_dispatch.pop_key() + + +@iterator.runtime.closure.register(TRACING) +def closure(domain, stencil, outputs, inputs): + stencil(*list(_s(param) for param in inspect.signature(stencil).parameters.keys())) + Tracer.add_closure( + StencilClosure( + domain=domain, + stencil=make_node(stencil), + outputs=outputs, + inputs=inputs, + ) + ) + + +def fendef_tracing(fun, *args, **kwargs): + with Tracer() as _: + trace_function_call(fun) + + fencil = FencilDefinition( + id=fun.__name__, + params=list(Sym(id=param) for param in inspect.signature(fun).parameters.keys()), + closures=Tracer.closures, + ) + prog = Program(function_definitions=Tracer.fundefs, fencil_definitions=[fencil], setqs=[]) + # after tracing is done + execute_program(prog, *args, **kwargs) + + +iterator.runtime.fendef_registry[lambda kwargs: "backend" in kwargs] = fendef_tracing 
diff --git a/src/iterator/transforms/__init__.py b/src/iterator/transforms/__init__.py new file mode 100644 index 0000000000..1fc5154fa0 --- /dev/null +++ b/src/iterator/transforms/__init__.py @@ -0,0 +1,4 @@ +from iterator.transforms.common import apply_common_transforms + + +__all__ = ["apply_common_transforms"] diff --git a/src/iterator/transforms/collect_shifts.py b/src/iterator/transforms/collect_shifts.py new file mode 100644 index 0000000000..9e326295de --- /dev/null +++ b/src/iterator/transforms/collect_shifts.py @@ -0,0 +1,31 @@ +from typing import Dict, List + +from eve import NodeVisitor +from iterator import ir + + +class CollectShifts(NodeVisitor): + def visit_FunCall(self, node: ir.FunCall, *, shifts: Dict[str, List[tuple]]): + if isinstance(node.fun, ir.SymRef) and node.fun.id == "deref": + assert len(node.args) == 1 + arg = node.args[0] + if isinstance(arg, ir.SymRef): + # direct deref of a symbol: deref(sym) + shifts.setdefault(arg.id, []).append(()) + elif ( + isinstance(arg, ir.FunCall) + and isinstance(arg.fun, ir.FunCall) + and isinstance(arg.fun.fun, ir.SymRef) + and arg.fun.fun.id == "shift" + and isinstance(arg.args[0], ir.SymRef) + ): + # deref of a shifted symbol: deref(shift(...)(sym)) + assert len(arg.args) == 1 + sym = arg.args[0] + shift_args = arg.fun.args + shifts.setdefault(sym.id, []).append(tuple(shift_args)) + else: + raise RuntimeError(f"Unexpected node: {node}") + elif isinstance(node.fun, ir.SymRef) and node.fun.id in ("lift", "scan"): + raise RuntimeError(f"Unsupported node: {node}") + return self.generic_visit(node, shifts=shifts) diff --git a/src/iterator/transforms/common.py b/src/iterator/transforms/common.py new file mode 100644 index 0000000000..349b5ad237 --- /dev/null +++ b/src/iterator/transforms/common.py @@ -0,0 +1,21 @@ +from iterator.transforms.global_tmps import CreateGlobalTmps +from iterator.transforms.inline_fundefs import InlineFundefs, PruneUnreferencedFundefs +from iterator.transforms.inline_lambdas import 
InlineLambdas +from iterator.transforms.inline_lifts import InlineLifts +from iterator.transforms.normalize_shifts import NormalizeShifts + + +def apply_common_transforms(ir, use_tmps=False, offset_provider=None, register_tmp=None): + ir = InlineFundefs().visit(ir) + ir = PruneUnreferencedFundefs().visit(ir) + ir = NormalizeShifts().visit(ir) + if not use_tmps: + ir = InlineLifts().visit(ir) + ir = InlineLambdas().visit(ir) + ir = NormalizeShifts().visit(ir) + if use_tmps: + assert offset_provider is not None + ir = CreateGlobalTmps().visit( + ir, offset_provider=offset_provider, register_tmp=register_tmp + ) + return ir diff --git a/src/iterator/transforms/global_tmps.py b/src/iterator/transforms/global_tmps.py new file mode 100644 index 0000000000..2d83524159 --- /dev/null +++ b/src/iterator/transforms/global_tmps.py @@ -0,0 +1,121 @@ +from typing import Dict, List + +from eve import NodeTranslator +from iterator import ir +from iterator.runtime import CartesianAxis +from iterator.transforms.collect_shifts import CollectShifts +from iterator.transforms.popup_tmps import PopupTmps + + +class CreateGlobalTmps(NodeTranslator): + @staticmethod + def _extend_domain(domain: ir.FunCall, offset_provider, shifts): + assert isinstance(domain.fun, ir.SymRef) and domain.fun.id == "domain" + assert all(isinstance(o, CartesianAxis) for o in offset_provider.values()) + + offset_limits = {k: (0, 0) for k in offset_provider.keys()} + + for shift in shifts: + offsets = {k: 0 for k in offset_provider.keys()} + for k, v in zip(shift[0::2], shift[1::2]): + offsets[k.value] += v.value + for k, v in offsets.items(): + old_min, old_max = offset_limits[k] + offset_limits[k] = (min(old_min, v), max(old_max, v)) + + offset_limits = {v.value: offset_limits[k] for k, v in offset_provider.items()} + + named_ranges = [] + for named_range in domain.args: + assert ( + isinstance(named_range, ir.FunCall) + and isinstance(named_range.fun, ir.SymRef) + and named_range.fun.id == "named_range" + ) + 
axis_literal, lower_bound, upper_bound = named_range.args + assert isinstance(axis_literal, ir.AxisLiteral) + + lower_offset, upper_offset = offset_limits.get(axis_literal.value, (0, 0)) + named_ranges.append( + ir.FunCall( + fun=named_range.fun, + args=[ + axis_literal, + ir.FunCall( + fun=ir.SymRef(id="plus"), + args=[lower_bound, ir.IntLiteral(value=lower_offset)], + ) + if lower_offset + else lower_bound, + ir.FunCall( + fun=ir.SymRef(id="plus"), + args=[upper_bound, ir.IntLiteral(value=upper_offset)], + ) + if upper_offset + else upper_bound, + ], + ) + ) + + return ir.FunCall(fun=domain.fun, args=named_ranges) + + def visit_FencilDefinition(self, node: ir.FencilDefinition, *, offset_provider, register_tmp): + tmps: List[ir.Sym] = [] + + def handle_arg(arg): + if isinstance(arg, ir.SymRef): + return arg + if ( + isinstance(arg, ir.FunCall) + and isinstance(arg.fun, ir.FunCall) + and arg.fun.fun.id == "lift" + ): + ref = ir.SymRef(id=f"tmp{len(tmps)}") + tmps.append(ir.Sym(id=ref.id)) + assert len(arg.fun.args) == 1 + unlifted = ir.FunCall(fun=arg.fun.args[0], args=arg.args) + todos.append(([ref], unlifted)) + return ref + raise AssertionError() + + closures = [] + tmp_domains = dict() + for closure in reversed(node.closures): + assert isinstance(closure.stencil, ir.Lambda) + wrapped_stencil = ir.FunCall(fun=closure.stencil, args=closure.inputs) + popped_stencil = PopupTmps().visit(wrapped_stencil) + todos = [(closure.outputs, popped_stencil)] + + shifts: Dict[str, List[tuple]] = dict() + domain = closure.domain + while todos: + outputs, call = todos.pop() + output_shifts: List[tuple] = sum((shifts.get(o.id, []) for o in outputs), []) + domain = self._extend_domain(domain, offset_provider, output_shifts) + for output in outputs: + if output.id in {tmp.id for tmp in tmps}: + assert output.id not in tmp_domains + tmp_domains[output.id] = domain + closure = ir.StencilClosure( + domain=domain, + stencil=call.fun, + outputs=outputs, + inputs=[handle_arg(arg) for arg 
in call.args], + ) + local_shifts: Dict[str, List[tuple]] = dict() + CollectShifts().visit(closure.stencil, shifts=local_shifts) + input_map = { + param.id: arg.id for param, arg in zip(closure.stencil.params, closure.inputs) + } + for k, v in local_shifts.items(): + shifts.setdefault(input_map[k], []).extend(v) + closures.append(closure) + + assert {tmp.id for tmp in tmps} == set(tmp_domains.keys()) + if register_tmp is not None: + for tmp, domain in tmp_domains.items(): + register_tmp(tmp, domain) + + return ir.FencilDefinition( + id=node.id, params=node.params + tmps, closures=list(reversed(closures)) + ) diff --git a/src/iterator/transforms/inline_fundefs.py b/src/iterator/transforms/inline_fundefs.py new file mode 100644 index 0000000000..308fef6953 --- /dev/null +++ b/src/iterator/transforms/inline_fundefs.py @@ -0,0 +1,35 @@ +from typing import Any, Dict, Set + +from eve import NOTHING, NodeTranslator +from iterator import ir + + +class InlineFundefs(NodeTranslator): + def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]): + if node.id in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition): + return ir.Lambda( + params=self.generic_visit(symbol.params, symtable=symtable), + expr=self.generic_visit(symbol.expr, symtable=symtable), + ) + return self.generic_visit(node) + + def visit_Program(self, node: ir.Program): + return self.generic_visit(node, symtable=node.symtable_) + + +class PruneUnreferencedFundefs(NodeTranslator): + def visit_FunctionDefinition( + self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool + ): + if second_pass and node.id not in referenced: + return NOTHING + return self.generic_visit(node, referenced=referenced, second_pass=second_pass) + + def visit_SymRef(self, node: ir.SymRef, *, referenced: Set[str], second_pass: bool): + referenced.add(node.id) + return node + + def visit_Program(self, node: ir.Program): + referenced: Set[str] = set() + self.generic_visit(node, 
referenced=referenced, second_pass=False) + return self.generic_visit(node, referenced=referenced, second_pass=True) diff --git a/src/iterator/transforms/inline_lambdas.py b/src/iterator/transforms/inline_lambdas.py new file mode 100644 index 0000000000..c38627acad --- /dev/null +++ b/src/iterator/transforms/inline_lambdas.py @@ -0,0 +1,32 @@ +from eve import NodeTranslator +from iterator import ir +from iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols + + +class InlineLambdas(NodeTranslator): + def visit_FunCall(self, node: ir.FunCall): + node = self.generic_visit(node) + if isinstance(node.fun, ir.Lambda): + assert len(node.fun.params) == len(node.args) + refs = set.union( + *( + arg.iter_tree().if_isinstance(ir.SymRef).getattr("id").to_set() + for arg in node.args + ) + ) + syms = node.fun.expr.iter_tree().if_isinstance(ir.Sym).getattr("id").to_set() + clashes = refs & syms + expr = node.fun.expr + if clashes: + + def new_name(name): + while name in refs or name in syms: + name += "_" + return name + + name_map = {sym: new_name(sym) for sym in clashes} + expr = RenameSymbols().visit(expr, name_map=name_map) + + symbol_map = {param.id: arg for param, arg in zip(node.fun.params, node.args)} + return RemapSymbolRefs().visit(expr, symbol_map=symbol_map) + return node diff --git a/src/iterator/transforms/inline_lifts.py b/src/iterator/transforms/inline_lifts.py new file mode 100644 index 0000000000..02af36f799 --- /dev/null +++ b/src/iterator/transforms/inline_lifts.py @@ -0,0 +1,37 @@ +from eve import NodeTranslator +from iterator import ir + + +class InlineLifts(NodeTranslator): + def visit_FunCall(self, node: ir.FunCall): + node = self.generic_visit(node) + if isinstance(node.fun, ir.SymRef) and node.fun.id == "deref": + assert len(node.args) == 1 + if ( + isinstance(node.args[0], ir.FunCall) + and isinstance(node.args[0].fun, ir.FunCall) + and isinstance(node.args[0].fun.fun, ir.SymRef) + and node.args[0].fun.fun.id == "lift" + ): + # 
deref(lift(f)(args...)) -> f(args...) + assert len(node.args[0].fun.args) == 1 + f = node.args[0].fun.args[0] + args = node.args[0].args + return ir.FunCall(fun=f, args=args) + elif ( + isinstance(node.args[0], ir.FunCall) + and isinstance(node.args[0].fun, ir.FunCall) + and isinstance(node.args[0].fun.fun, ir.SymRef) + and node.args[0].fun.fun.id == "shift" + and isinstance(node.args[0].args[0], ir.FunCall) + and isinstance(node.args[0].args[0].fun, ir.FunCall) + and isinstance(node.args[0].args[0].fun.fun, ir.SymRef) + and node.args[0].args[0].fun.fun.id == "lift" + ): + # deref(shift(...)(lift(f)(args...)) -> f(shift(...)(args)...) + f = node.args[0].args[0].fun.args[0] + shift = node.args[0].fun + args = node.args[0].args[0].args + res = ir.FunCall(fun=f, args=[ir.FunCall(fun=shift, args=[arg]) for arg in args]) + return res + return node diff --git a/src/iterator/transforms/normalize_shifts.py b/src/iterator/transforms/normalize_shifts.py new file mode 100644 index 0000000000..500ff232cf --- /dev/null +++ b/src/iterator/transforms/normalize_shifts.py @@ -0,0 +1,25 @@ +from eve import NodeTranslator +from iterator import ir + + +class NormalizeShifts(NodeTranslator): + def visit_FunCall(self, node: ir.FunCall): + node = self.generic_visit(node) + if ( + isinstance(node.fun, ir.FunCall) + and isinstance(node.fun.fun, ir.SymRef) + and node.fun.fun.id == "shift" + and node.args + and isinstance(node.args[0], ir.FunCall) + and isinstance(node.args[0].fun.fun, ir.SymRef) + and node.args[0].fun.fun.id == "shift" + ): + # shift(args1...)(shift(args2...)(it)) -> shift(args2..., args1...)(it) + assert len(node.args) == 1 + return ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="shift"), args=node.args[0].fun.args + node.fun.args + ), + args=node.args[0].args, + ) + return node diff --git a/src/iterator/transforms/popup_tmps.py b/src/iterator/transforms/popup_tmps.py new file mode 100644 index 0000000000..b67edb5cb3 --- /dev/null +++ 
b/src/iterator/transforms/popup_tmps.py @@ -0,0 +1,60 @@ +from typing import Dict, Optional + +from eve import NodeTranslator +from iterator import ir +from iterator.transforms.remap_symbols import RemapSymbolRefs + + +class PopupTmps(NodeTranslator): + _counter = 0 + + def visit_FunCall(self, node: ir.FunCall, *, lifts: Optional[Dict[str, ir.Node]] = None): + if ( + isinstance(node.fun, ir.FunCall) + and isinstance(node.fun.fun, ir.SymRef) + and node.fun.fun.id == "lift" + ): + # lifted lambda call + assert len(node.fun.args) == 1 and isinstance(node.fun.args[0], ir.Lambda) + assert lifts is not None + + nested_lifts: Dict[str, ir.Node] = dict() + res = self.generic_visit(node, lifts=nested_lifts) + # TODO: avoid possible symbol name clashes + symbol = f"t{self._counter}" + self._counter += 1 + + symbol_map = {param.id: arg for param, arg in zip(res.fun.args[0].params, res.args)} + new_args = [ + RemapSymbolRefs().visit(arg, symbol_map=symbol_map) for arg in nested_lifts.values() + ] + assert len(res.fun.args[0].params) == len(res.args + new_args) + call = ir.FunCall(fun=res.fun, args=res.args + new_args) + + # return existing definition if the same expression was lifted before + for k, v in lifts.items(): + if call == v: + return ir.SymRef(id=k) + + lifts[symbol] = call + return ir.SymRef(id=symbol) + elif isinstance(node.fun, ir.Lambda): + # direct lambda call + lifts = dict() + res = self.generic_visit(node, lifts=lifts) + symbol_map = {param.id: arg for param, arg in zip(res.fun.params, res.args)} + new_args = [ + RemapSymbolRefs().visit(arg, symbol_map=symbol_map) for arg in lifts.values() + ] + assert len(res.fun.params) == len(res.args + new_args) + return ir.FunCall(fun=res.fun, args=res.args + new_args) + + return self.generic_visit(node, lifts=lifts) + + def visit_Lambda(self, node: ir.Lambda, *, lifts): + node = self.generic_visit(node, lifts=lifts) + if not lifts: + return node + + new_params = [ir.Sym(id=param) for param in lifts.keys()] + return 
ir.Lambda(params=node.params + new_params, expr=node.expr) diff --git a/src/iterator/transforms/remap_symbols.py b/src/iterator/transforms/remap_symbols.py new file mode 100644 index 0000000000..8a17944b59 --- /dev/null +++ b/src/iterator/transforms/remap_symbols.py @@ -0,0 +1,48 @@ +from typing import Any, Dict, Optional, Set + +from eve import NodeTranslator +from iterator import ir + + +class RemapSymbolRefs(NodeTranslator): + def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): + return symbol_map.get(node.id, node) + + def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]): + params = {str(p.id) for p in node.params} + new_symbol_map = {k: v for k, v in symbol_map.items() if k not in params} + return ir.Lambda( + params=node.params, + expr=self.generic_visit(node.expr, symbol_map=new_symbol_map), + ) + + def generic_visit(self, node: ir.Node, **kwargs: Any): + assert isinstance(node, ir.SymbolTableTrait) == isinstance( + node, ir.Lambda + ), "found unexpected new symbol scope" + return super().generic_visit(node, **kwargs) + + +class RenameSymbols(NodeTranslator): + def visit_Sym( + self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None + ): + if active and node.id in active: + return ir.Sym(id=name_map.get(node.id, node.id)) + return node + + def visit_SymRef( + self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None + ): + if active and node.id in active: + return ir.SymRef(id=name_map.get(node.id, node.id)) + return node + + def generic_visit( + self, node: ir.Node, *, name_map: Dict[str, str], active: Optional[Set[str]] = None + ): + if isinstance(node, ir.SymbolTableTrait): + if active is None: + active = set() + active = active | set(node.symtable_) + return super().generic_visit(node, name_map=name_map, active=active) diff --git a/src/iterator/util/__init__.py b/src/iterator/util/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/src/iterator/util/sym_validation.py b/src/iterator/util/sym_validation.py new file mode 100644 index 0000000000..17af043739 --- /dev/null +++ b/src/iterator/util/sym_validation.py @@ -0,0 +1,49 @@ +from typing import Any, Dict, List, Type + +import pydantic + +from eve import Node +from eve.traits import SymbolTableTrait +from eve.type_definitions import SymbolRef +from eve.typingx import RootValidatorType, RootValidatorValuesType +from eve.visitors import NodeVisitor + + +def validate_symbol_refs() -> RootValidatorType: + """Validate that symbol refs are found in a symbol table valid at the current scope.""" + + def _impl( + cls: Type[pydantic.BaseModel], values: RootValidatorValuesType + ) -> RootValidatorValuesType: + class SymtableValidator(NodeVisitor): + def __init__(self) -> None: + self.missing_symbols: List[str] = [] + + def visit_Node(self, node: Node, *, symtable: Dict[str, Any], **kwargs: Any) -> None: + for name, metadata in node.__node_children__.items(): + if isinstance(metadata["definition"].type_, type) and issubclass( + metadata["definition"].type_, SymbolRef + ): + if getattr(node, name) and getattr(node, name) not in symtable: + self.missing_symbols.append(getattr(node, name)) + + if isinstance(node, SymbolTableTrait): + symtable = {**symtable, **node.symtable_} + self.generic_visit(node, symtable=symtable, **kwargs) + + @classmethod + def apply(cls, node: Node, *, symtable: Dict[str, Any]) -> List[str]: + instance = cls() + instance.visit(node, symtable=symtable) + return instance.missing_symbols + + missing_symbols = [] + for v in values.values(): + missing_symbols.extend(SymtableValidator.apply(v, symtable=values["symtable_"])) + + if len(missing_symbols) > 0: + raise ValueError("Symbols {} not found.".format(missing_symbols)) + + return values + + return pydantic.root_validator(allow_reuse=True, skip_on_failure=True)(_impl) diff --git a/src/iterator/utils.py b/src/iterator/utils.py new file mode 100644 index 0000000000..dd7fa908a9 --- 
/dev/null +++ b/src/iterator/utils.py @@ -0,0 +1,19 @@ +# GT4Py New Semantic Model - GridTools Framework +# +# Copyright (c) 2014-2021, ETH Zurich All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. GT4Py +# New Semantic Model is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or any later version. +# See the LICENSE.txt file at the top-level directory of this distribution for +# a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +def tupelize(tup): + if isinstance(tup, tuple): + return tup + else: + return (tup,) diff --git a/tests/iterator_tests/__init__.py b/tests/iterator_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/iterator_tests/conftest.py b/tests/iterator_tests/conftest.py new file mode 100644 index 0000000000..07218b13aa --- /dev/null +++ b/tests/iterator_tests/conftest.py @@ -0,0 +1,20 @@ +import pytest + + +@pytest.fixture(params=[False, True], ids=lambda p: f"use_tmps={p}") +def use_tmps(request): + return request.param + + +@pytest.fixture( + params=[ + # (backend, do_validate) + ("lisp", False), + ("cpptoy", False), + ("embedded", True), + ("double_roundtrip", True), + ], + ids=lambda p: f"backend={p[0]}", +) +def backend(request): + return request.param diff --git a/tests/iterator_tests/fvm_nabla_setup.py b/tests/iterator_tests/fvm_nabla_setup.py new file mode 100644 index 0000000000..e4d7379583 --- /dev/null +++ b/tests/iterator_tests/fvm_nabla_setup.py @@ -0,0 +1,211 @@ +# GT4Py New Semantic Model - GridTools Framework +# +# Copyright (c) 2014-2021, ETH Zurich All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. 
GT4Py +# New Semantic Model is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or any later version. +# See the LICENSE.txt file at the top-level directory of this distribution for +# a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import math + +import numpy as np +from atlas4py import ( + Config, + StructuredGrid, + StructuredMeshGenerator, + Topology, + build_edges, + build_median_dual_mesh, + build_node_to_edge_connectivity, + functionspace, +) + + +def assert_close(expected, actual): + assert math.isclose(expected, actual), "expected={}, actual={}".format(expected, actual) + + +class nabla_setup: + @staticmethod + def _default_config(): + config = Config() + config["triangulate"] = True + config["angle"] = 20.0 + return config + + def __init__(self, *, grid=StructuredGrid("O32"), config=None): + if config is None: + config = self._default_config() + mesh = StructuredMeshGenerator(config).generate(grid) + + fs_edges = functionspace.EdgeColumns(mesh, halo=1) + fs_nodes = functionspace.NodeColumns(mesh, halo=1) + + build_edges(mesh) + build_node_to_edge_connectivity(mesh) + build_median_dual_mesh(mesh) + + edges_per_node = max( + [mesh.nodes.edge_connectivity.cols(node) for node in range(0, fs_nodes.size)] + ) + + self.mesh = mesh + self.fs_edges = fs_edges + self.fs_nodes = fs_nodes + self.edges_per_node = edges_per_node + + @property + def edges2node_connectivity(self): + return self.mesh.edges.node_connectivity + + @property + def nodes2edge_connectivity(self): + return self.mesh.nodes.edge_connectivity + + @property + def nodes_size(self): + return self.fs_nodes.size + + @property + def edges_size(self): + return self.fs_edges.size + + @staticmethod + def _is_pole_edge(e, edge_flags): + return Topology.check(edge_flags[e], Topology.POLE) + + @property + def is_pole_edge_field(self): 
+ edge_flags = np.array(self.mesh.edges.flags()) + + pole_edge_field = np.zeros((self.edges_size,)) + for e in range(self.edges_size): + pole_edge_field[e] = self._is_pole_edge(e, edge_flags) + return edge_flags + + @property + def sign_field(self): + node2edge_sign = np.zeros((self.nodes_size, self.edges_per_node)) + edge_flags = np.array(self.mesh.edges.flags()) + + for jnode in range(0, self.nodes_size): + node_edge_con = self.mesh.nodes.edge_connectivity + edge_node_con = self.mesh.edges.node_connectivity + for jedge in range(0, node_edge_con.cols(jnode)): + iedge = node_edge_con[jnode, jedge] + ip1 = edge_node_con[iedge, 0] + if jnode == ip1: + node2edge_sign[jnode, jedge] = 1.0 + else: + node2edge_sign[jnode, jedge] = -1.0 + if self._is_pole_edge(iedge, edge_flags): + node2edge_sign[jnode, jedge] = 1.0 + return node2edge_sign + + @property + def S_fields(self): + S = np.array(self.mesh.edges.field("dual_normals"), copy=False) + S_MXX = np.zeros((self.edges_size)) + S_MYY = np.zeros((self.edges_size)) + + MXX = 0 + MYY = 1 + + rpi = 2.0 * math.asin(1.0) + radius = 6371.22e03 + deg2rad = 2.0 * rpi / 360.0 + + for i in range(0, self.edges_size): + S_MXX[i] = S[i, MXX] * radius * deg2rad + S_MYY[i] = S[i, MYY] * radius * deg2rad + + assert math.isclose(min(S_MXX), -103437.60479272791) + assert math.isclose(max(S_MXX), 340115.33913622628) + assert math.isclose(min(S_MYY), -2001577.7946404363) + assert math.isclose(max(S_MYY), 2001577.7946404363) + + return S_MXX, S_MYY + + @property + def vol_field(self): + rpi = 2.0 * math.asin(1.0) + radius = 6371.22e03 + deg2rad = 2.0 * rpi / 360.0 + vol_atlas = np.array(self.mesh.nodes.field("dual_volumes"), copy=False) + # dual_volumes 4.6510228700066421 68.891611253882218 12.347560975609632 + assert_close(4.6510228700066421, min(vol_atlas)) + assert_close(68.891611253882218, max(vol_atlas)) + + vol = np.zeros((vol_atlas.size)) + for i in range(0, vol_atlas.size): + vol[i] = vol_atlas[i] * pow(deg2rad, 2) * pow(radius, 2) + # 
VOL(min/max): 57510668192.214096 851856184496.32886 + assert_close(57510668192.214096, min(vol)) + assert_close(851856184496.32886, max(vol)) + return vol + + @property + def input_field(self): + klevel = 0 + MXX = 0 + MYY = 1 + rpi = 2.0 * math.asin(1.0) + radius = 6371.22e03 + deg2rad = 2.0 * rpi / 360.0 + + zh0 = 2000.0 + zrad = 3.0 * rpi / 4.0 * radius + zeta = rpi / 16.0 * radius + zlatc = 0.0 + zlonc = 3.0 * rpi / 2.0 + + m_rlonlatcr = self.fs_nodes.create_field( + name="m_rlonlatcr", + levels=1, + dtype=np.float64, + variables=self.edges_per_node, + ) + rlonlatcr = np.array(m_rlonlatcr, copy=False) + + m_rcoords = self.fs_nodes.create_field( + name="m_rcoords", levels=1, dtype=np.float64, variables=self.edges_per_node + ) + rcoords = np.array(m_rcoords, copy=False) + + m_rcosa = self.fs_nodes.create_field(name="m_rcosa", levels=1, dtype=np.float64) + rcosa = np.array(m_rcosa, copy=False) + + m_rsina = self.fs_nodes.create_field(name="m_rsina", levels=1, dtype=np.float64) + rsina = np.array(m_rsina, copy=False) + + m_pp = self.fs_nodes.create_field(name="m_pp", levels=1, dtype=np.float64) + rzs = np.array(m_pp, copy=False) + + rcoords_deg = np.array(self.mesh.nodes.field("lonlat")) + + for jnode in range(0, self.nodes_size): + for i in range(0, 2): + rcoords[jnode, klevel, i] = rcoords_deg[jnode, i] * deg2rad + rlonlatcr[jnode, klevel, i] = rcoords[jnode, klevel, i] # This is not my pattern! 
+ rcosa[jnode, klevel] = math.cos(rlonlatcr[jnode, klevel, MYY]) + rsina[jnode, klevel] = math.sin(rlonlatcr[jnode, klevel, MYY]) + for jnode in range(0, self.nodes_size): + zlon = rlonlatcr[jnode, klevel, MXX] + zdist = math.sin(zlatc) * rsina[jnode, klevel] + math.cos(zlatc) * rcosa[ + jnode, klevel + ] * math.cos(zlon - zlonc) + zdist = radius * math.acos(zdist) + rzs[jnode, klevel] = 0.0 + if zdist < zrad: + rzs[jnode, klevel] = rzs[jnode, klevel] + 0.5 * zh0 * ( + 1.0 + math.cos(rpi * zdist / zrad) + ) * math.pow(math.cos(rpi * zdist / zeta), 2) + + assert_close(0.0000000000000000, min(rzs)) + assert_close(1965.4980340735883, max(rzs)) + return rzs[:, klevel] diff --git a/tests/iterator_tests/hdiff_reference.py b/tests/iterator_tests/hdiff_reference.py new file mode 100644 index 0000000000..9788746fde --- /dev/null +++ b/tests/iterator_tests/hdiff_reference.py @@ -0,0 +1,41 @@ +# GT4Py New Semantic Model - GridTools Framework +# +# Copyright (c) 2014-2021, ETH Zurich All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. GT4Py +# New Semantic Model is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or any later version. +# See the LICENSE.txt file at the top-level directory of this distribution for +# a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import numpy as np +import pytest + + +def hdiff_reference_impl(): + shape = (5, 7, 5) + rng = np.random.default_rng() + inp = rng.normal(size=(shape[0] + 4, shape[1] + 4, shape[2])) + coeff = rng.normal(size=shape) + + lap = 4 * inp[1:-1, 1:-1, :] - ( + inp[2:, 1:-1, :] + inp[:-2, 1:-1, :] + inp[1:-1, 2:, :] + inp[1:-1, :-2, :] + ) + uflx = lap[1:, 1:-1, :] - lap[:-1, 1:-1, :] + flx = np.where(uflx * (inp[2:-1, 2:-2, :] - inp[1:-2, 2:-2, :]) > 0, 0, uflx) + ufly = lap[1:-1, 1:, :] - lap[1:-1, :-1, :] + fly = np.where(ufly * (inp[2:-2, 2:-1, :] - inp[2:-2, 1:-2, :]) > 0, 0, ufly) + out = inp[2:-2, 2:-2, :] - coeff * ( + flx[1:, :, :] - flx[:-1, :, :] + fly[:, 1:, :] - fly[:, :-1, :] + ) + + return inp, coeff, out + + +@pytest.fixture +def hdiff_reference(): + return hdiff_reference_impl() diff --git a/tests/iterator_tests/test_anton_toy.py b/tests/iterator_tests/test_anton_toy.py new file mode 100644 index 0000000000..1ba8ba4865 --- /dev/null +++ b/tests/iterator_tests/test_anton_toy.py @@ -0,0 +1,82 @@ +import numpy as np + +from iterator.builtins import deref, domain, lift, named_range, shift +from iterator.embedded import np_as_located_field +from iterator.runtime import CartesianAxis, closure, fendef, fundef, offset + + +@fundef +def ldif(d): + return lambda inp: deref(shift(d, -1)(inp)) - deref(inp) + + +@fundef +def rdif(d): + # return compose(ldif(d), shift(d, 1)) # noqa: E800 + return lambda inp: ldif(d)(shift(d, 1)(inp)) + + +@fundef +def dif2(d): + # return compose(ldif(d), lift(rdif(d))) # noqa: E800 + return lambda inp: ldif(d)(lift(rdif(d))(inp)) + + +i = offset("i") +j = offset("j") + + +@fundef +def lap(inp): + return dif2(i)(inp) + dif2(j)(inp) + + +IDim = CartesianAxis("IDim") +JDim = CartesianAxis("JDim") +KDim = CartesianAxis("KDim") + + +@fendef(offset_provider={"i": IDim, "j": JDim}) +def fencil(x, y, z, out, inp): + closure( + domain(named_range(IDim, 0, x), named_range(JDim, 0, y), 
named_range(KDim, 0, z)), + lap, + [out], + [inp], + ) + + +def naive_lap(inp): + shape = [inp.shape[0] - 2, inp.shape[1] - 2, inp.shape[2]] + out = np.zeros(shape) + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(0, shape[2]): + out[i, j, k] = -4 * inp[i, j, k] + ( + inp[i + 1, j, k] + inp[i - 1, j, k] + inp[i, j + 1, k] + inp[i, j - 1, k] + ) + return out + + +def test_anton_toy(backend, use_tmps): + backend, validate = backend + shape = [5, 7, 9] + rng = np.random.default_rng() + inp = np_as_located_field(IDim, JDim, KDim, origin={IDim: 1, JDim: 1, KDim: 0})( + rng.normal(size=(shape[0] + 2, shape[1] + 2, shape[2])), + ) + out = np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) + ref = naive_lap(inp) + + fencil( + shape[0], + shape[1], + shape[2], + out, + inp, + backend=backend, + use_tmps=use_tmps, + ) + + if validate: + assert np.allclose(out, ref) diff --git a/tests/iterator_tests/test_cartesian_offset_provider.py b/tests/iterator_tests/test_cartesian_offset_provider.py new file mode 100644 index 0000000000..88ee0405a6 --- /dev/null +++ b/tests/iterator_tests/test_cartesian_offset_provider.py @@ -0,0 +1,80 @@ +import numpy as np + +from iterator.builtins import * +from iterator.embedded import np_as_located_field +from iterator.runtime import * + + +I = offset("I") +J = offset("J") +I_loc = CartesianAxis("I_loc") +J_loc = CartesianAxis("J_loc") + + +@fundef +def foo(inp): + return deref(shift(J, 1)(inp)) + + +@fendef(offset_provider={"I": I_loc, "J": J_loc}) +def fencil(output, input): + closure( + domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), + foo, + [output], + [input], + ) + + +@fendef(offset_provider={"I": J_loc, "J": I_loc}) +def fencil_swapped(output, input): + closure( + domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), + foo, + [output], + [input], + ) + + +def test_cartesian_offset_provider(): + inp = np_as_located_field(I_loc, J_loc)(np.asarray([[0, 42], [1, 43]])) + out = 
np_as_located_field(I_loc, J_loc)(np.asarray([[-1]])) + + fencil(out, inp) + assert out[0][0] == 42 + + fencil_swapped(out, inp) + assert out[0][0] == 1 + + fencil(out, inp, backend="embedded") + assert out[0][0] == 42 + + fencil(out, inp, backend="double_roundtrip") + assert out[0][0] == 42 + + +@fundef +def delay_complete_shift(inp): + return deref(shift(I, J, 1, 1)(inp)) + + +@fendef(offset_provider={"I": J_loc, "J": I_loc}) +def delay_complete_shift_fencil(output, input): + closure( + domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), + delay_complete_shift, + [output], + [input], + ) + + +def test_delay_complete_shift(): + inp = np_as_located_field(I_loc, J_loc)(np.asarray([[0, 42], [1, 43]])) + + out = np_as_located_field(I_loc, J_loc)(np.asarray([[-1]])) + delay_complete_shift_fencil(out, inp) + assert out[0, 0] == 43 + + out = np_as_located_field(I_loc, J_loc)(np.asarray([[-1]])) + delay_complete_shift_fencil(out, inp, backend="embedded", debug=True) + assert out[0, 0] == 43 diff --git a/tests/iterator_tests/test_column_stencil.py b/tests/iterator_tests/test_column_stencil.py new file mode 100644 index 0000000000..6e87424103 --- /dev/null +++ b/tests/iterator_tests/test_column_stencil.py @@ -0,0 +1,159 @@ +import numpy as np +import pytest + +from iterator.builtins import * +from iterator.embedded import np_as_located_field +from iterator.runtime import * + + +I = offset("I") +K = offset("K") + + +@fundef +def multiply_stencil(inp): + return deref(shift(K, 1, I, 1)(inp)) + + +KDim = CartesianAxis("KDim") +IDim = CartesianAxis("IDim") + + +@fendef(column_axis=KDim) +def fencil(i_size, k_size, inp, out): + closure( + domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)), + multiply_stencil, + [out], + [inp], + ) + + +def test_column_stencil(backend, use_tmps): + backend, validate = backend + shape = [5, 7] + inp = np_as_located_field(IDim, KDim)( + np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 1]) + ) + out = 
np_as_located_field(IDim, KDim)(np.zeros(shape)) + + ref = np.asarray(inp)[1:, 1:] + + fencil( + shape[0], + shape[1], + inp, + out, + offset_provider={"I": IDim, "K": KDim}, + backend=backend, + use_tmps=use_tmps, + ) + + if validate: + assert np.allclose(ref, out) + + +def test_column_stencil_with_k_origin(backend, use_tmps): + backend, validate = backend + shape = [5, 7] + raw_inp = np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 2]) + inp = np_as_located_field(IDim, KDim, origin={IDim: 0, KDim: 1})(raw_inp) + out = np_as_located_field(IDim, KDim)(np.zeros(shape)) + + ref = np.asarray(inp)[1:, 2:] + + fencil( + shape[0], + shape[1], + inp, + out, + offset_provider={"I": IDim, "K": KDim}, + backend=backend, + use_tmps=use_tmps, + ) + + if validate: + assert np.allclose(ref, out) + + +@fundef +def sum_scanpass(state, inp): + return if_(is_none(state), deref(inp), state + deref(inp)) + + +@fundef +def ksum(inp): + return scan(sum_scanpass, True, None)(inp) + + +@fendef(column_axis=KDim) +def ksum_fencil(i_size, k_size, inp, out): + closure( + domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)), + ksum, + [out], + [inp], + ) + + +def test_ksum_scan(backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently not supported for scans") + backend, validate = backend + shape = [1, 7] + inp = np_as_located_field(IDim, KDim)(np.asarray([list(range(7))])) + out = np_as_located_field(IDim, KDim)(np.zeros(shape)) + + ref = np.asarray([[0, 1, 3, 6, 10, 15, 21]]) + + ksum_fencil( + shape[0], + shape[1], + inp, + out, + offset_provider={"I": IDim, "K": KDim}, + backend=backend, + use_tmps=use_tmps, + ) + + if validate: + assert np.allclose(ref, np.asarray(out)) + + +@fundef +def ksum_back(inp): + return scan(sum_scanpass, False, None)(inp) + + +@fendef(column_axis=KDim) +def ksum_back_fencil(i_size, k_size, inp, out): + closure( + domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)), + ksum_back, + [out], + [inp], + ) + + 
+def test_ksum_back_scan(backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently not supported for scans") + backend, validate = backend + shape = [1, 7] + inp = np_as_located_field(IDim, KDim)(np.asarray([list(range(7))])) + out = np_as_located_field(IDim, KDim)(np.zeros(shape)) + + ref = np.asarray([[21, 21, 20, 18, 15, 11, 6]]) + + ksum_back_fencil( + shape[0], + shape[1], + inp, + out, + offset_provider={"I": IDim, "K": KDim}, + backend=backend, + use_tmps=use_tmps, + ) + + if validate: + assert np.allclose(ref, np.asarray(out)) diff --git a/tests/iterator_tests/test_fvm_nabla.py b/tests/iterator_tests/test_fvm_nabla.py new file mode 100644 index 0000000000..4ab209a4e7 --- /dev/null +++ b/tests/iterator_tests/test_fvm_nabla.py @@ -0,0 +1,247 @@ +# GT4Py New Semantic Model - GridTools Framework +# +# Copyright (c) 2014-2021, ETH Zurich All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. GT4Py +# New Semantic Model is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or any later version. +# See the LICENSE.txt file at the top-level directory of this distribution for +# a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np +import pytest + +from iterator import library +from iterator.atlas_utils import AtlasTable +from iterator.builtins import * +from iterator.embedded import NeighborTableOffsetProvider, index_field, np_as_located_field +from iterator.runtime import * + +from .fvm_nabla_setup import assert_close, nabla_setup + + +Vertex = CartesianAxis("Vertex") +Edge = CartesianAxis("Edge") + +V2E = offset("V2E") +E2V = offset("E2V") + + +@fundef +def compute_zavgS(pp, S_M): + zavg = 0.5 * (deref(shift(E2V, 0)(pp)) + deref(shift(E2V, 1)(pp))) + # zavg = 0.5 * reduce(lambda a, b: a + b, 0)(shift(E2V)(pp)) + # zavg = 0.5 * library.sum()(shift(E2V)(pp)) + return deref(S_M) * zavg + + +@fendef +def compute_zavgS_fencil( + n_edges, + out, + pp, + S_M, +): + closure( + domain(named_range(Edge, 0, n_edges)), + compute_zavgS, + [out], + [pp, S_M], + ) + + +@fundef +def compute_pnabla(pp, S_M, sign, vol): + zavgS = lift(compute_zavgS)(pp, S_M) + # pnabla_M = reduce(lambda a, b, c: a + b * c, 0)(shift(V2E)(zavgS), sign) + # pnabla_M = library.sum(lambda a, b: a * b)(shift(V2E)(zavgS), sign) + pnabla_M = library.dot(shift(V2E)(zavgS), sign) + return pnabla_M / deref(vol) + + +@fendef +def nabla( + n_nodes, + out_MXX, + out_MYY, + pp, + S_MXX, + S_MYY, + sign, + vol, +): + # TODO replace by single stencil which returns tuple + closure( + domain(named_range(Vertex, 0, n_nodes)), + compute_pnabla, + [out_MXX], + [pp, S_MXX, sign, vol], + ) + closure( + domain(named_range(Vertex, 0, n_nodes)), + compute_pnabla, + [out_MYY], + [pp, S_MYY, sign, vol], + ) + + +def test_compute_zavgS(backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently only supported for cartesian") + backend, validate = backend + setup = nabla_setup() + + pp = np_as_located_field(Vertex)(setup.input_field) + S_MXX, S_MYY = tuple(map(np_as_located_field(Edge), setup.S_fields)) + + zavgS = np_as_located_field(Edge)(np.zeros((setup.edges_size))) + + 
e2v = NeighborTableOffsetProvider(AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2) + + compute_zavgS_fencil( + setup.edges_size, + zavgS, + pp, + S_MXX, + offset_provider={"E2V": e2v}, + ) + + if validate: + assert_close(-199755464.25741270, min(zavgS)) + assert_close(388241977.58389181, max(zavgS)) + + compute_zavgS_fencil( + setup.edges_size, + zavgS, + pp, + S_MYY, + offset_provider={"E2V": e2v}, + ) + if validate: + assert_close(-1000788897.3202186, min(zavgS)) + assert_close(1000788897.3202186, max(zavgS)) + + +def test_nabla(backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently only supported for cartesian") + backend, validate = backend + setup = nabla_setup() + + sign = np_as_located_field(Vertex, V2E)(setup.sign_field) + pp = np_as_located_field(Vertex)(setup.input_field) + S_MXX, S_MYY = tuple(map(np_as_located_field(Edge), setup.S_fields)) + vol = np_as_located_field(Vertex)(setup.vol_field) + + pnabla_MXX = np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + pnabla_MYY = np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + + e2v = NeighborTableOffsetProvider(AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2) + v2e = NeighborTableOffsetProvider(AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7) + + nabla( + setup.nodes_size, + pnabla_MXX, + pnabla_MYY, + pp, + S_MXX, + S_MYY, + sign, + vol, + offset_provider={"E2V": e2v, "V2E": v2e}, + backend=backend, + use_tmps=use_tmps, + ) + + if validate: + assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) + assert_close(3.5455427772565435e-003, max(pnabla_MXX)) + assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) + assert_close(3.3540113705465301e-003, max(pnabla_MYY)) + + +@fundef +def sign(node_indices, is_pole_edge): + node_index = deref(node_indices) + + @fundef + def sign_impl(node_index): + def impl2(node_indices, is_pole_edge): + return if_( + or_(deref(is_pole_edge), eq(node_index, deref(shift(E2V, 0)(node_indices)))), + 1.0, + -1.0, + ) 
+ + return impl2 + + return shift(V2E)(lift(sign_impl(node_index))(node_indices, is_pole_edge)) + + +@fundef +def compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge): + zavgS = lift(compute_zavgS)(pp, S_M) + pnabla_M = library.dot(shift(V2E)(zavgS), sign(node_index, is_pole_edge)) + + return pnabla_M / deref(vol) + + +@fendef +def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge): + # TODO replace by single stencil which returns tuple + closure( + domain(named_range(Vertex, 0, n_nodes)), + compute_pnabla_sign, + [out_MXX], + [pp, S_MXX, vol, node_index, is_pole_edge], + ) + closure( + domain(named_range(Vertex, 0, n_nodes)), + compute_pnabla_sign, + [out_MYY], + [pp, S_MYY, vol, node_index, is_pole_edge], + ) + + +def test_nabla_sign(backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently only supported for cartesian") + + backend, validate = backend + setup = nabla_setup() + + # sign = np_as_located_field(Vertex, V2E)(setup.sign_field) + is_pole_edge = np_as_located_field(Edge)(setup.is_pole_edge_field) + pp = np_as_located_field(Vertex)(setup.input_field) + S_MXX, S_MYY = tuple(map(np_as_located_field(Edge), setup.S_fields)) + vol = np_as_located_field(Vertex)(setup.vol_field) + + pnabla_MXX = np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + pnabla_MYY = np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + + e2v = NeighborTableOffsetProvider(AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2) + v2e = NeighborTableOffsetProvider(AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7) + + nabla_sign( + setup.nodes_size, + pnabla_MXX, + pnabla_MYY, + pp, + S_MXX, + S_MYY, + vol, + index_field(Vertex), + is_pole_edge, + offset_provider={"E2V": e2v, "V2E": v2e}, + backend=backend, + use_tmps=use_tmps, + ) + + if validate: + assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) + assert_close(3.5455427772565435e-003, max(pnabla_MXX)) + assert_close(-3.3540113705465301e-003, 
min(pnabla_MYY)) + assert_close(3.3540113705465301e-003, max(pnabla_MYY)) diff --git a/tests/iterator_tests/test_hdiff.py b/tests/iterator_tests/test_hdiff.py new file mode 100644 index 0000000000..56688f6195 --- /dev/null +++ b/tests/iterator_tests/test_hdiff.py @@ -0,0 +1,69 @@ +import numpy as np + +from iterator.builtins import * +from iterator.embedded import np_as_located_field +from iterator.runtime import * + +from .hdiff_reference import hdiff_reference + + +I = offset("I") +J = offset("J") + +IDim = CartesianAxis("IDim") +JDim = CartesianAxis("JDim") + + +@fundef +def laplacian(inp): + return -4.0 * deref(inp) + ( + deref(shift(I, 1)(inp)) + + deref(shift(I, -1)(inp)) + + deref(shift(J, 1)(inp)) + + deref(shift(J, -1)(inp)) + ) + + +@fundef +def flux(d): + def flux_impl(inp): + lap = lift(laplacian)(inp) + flux = deref(lap) - deref(shift(d, 1)(lap)) + return if_(flux * (deref(shift(d, 1)(inp)) - deref(inp)) > 0.0, 0.0, flux) + + return flux_impl + + +@fundef +def hdiff_sten(inp, coeff): + flx = lift(flux(I))(inp) + fly = lift(flux(J))(inp) + return deref(inp) - ( + deref(coeff) + * (deref(flx) - deref(shift(I, -1)(flx)) + deref(fly) - deref(shift(J, -1)(fly))) + ) + + +@fendef(offset_provider={"I": IDim, "J": JDim}) +def hdiff(inp, coeff, out, x, y): + closure( + domain(named_range(IDim, 0, x), named_range(JDim, 0, y)), + hdiff_sten, + [out], + [inp, coeff], + ) + + +def test_hdiff(hdiff_reference, backend, use_tmps): + backend, validate = backend + inp, coeff, out = hdiff_reference + shape = (out.shape[0], out.shape[1]) + + inp_s = np_as_located_field(IDim, JDim, origin={IDim: 2, JDim: 2})(inp[:, :, 0]) + coeff_s = np_as_located_field(IDim, JDim)(coeff[:, :, 0]) + out_s = np_as_located_field(IDim, JDim)(np.zeros_like(coeff[:, :, 0])) + + hdiff(inp_s, coeff_s, out_s, shape[0], shape[1], backend=backend, use_tmps=use_tmps) + + if validate: + assert np.allclose(out[:, :, 0], out_s) diff --git a/tests/iterator_tests/test_horizontal_indirection.py 
b/tests/iterator_tests/test_horizontal_indirection.py new file mode 100644 index 0000000000..ac5496fe57 --- /dev/null +++ b/tests/iterator_tests/test_horizontal_indirection.py @@ -0,0 +1,55 @@ +# (defun calc (p_vn input_on_cell) +# (do_some_math +# (deref +# ((if (less (deref p_vn) 0) +# (shift e2c 0) +# (shift e2c 1) +# ) +# input_on_cell +# ) +# ) +# ) +# ) +import numpy as np +from numpy.core.numeric import allclose + +from iterator.builtins import * +from iterator.embedded import np_as_located_field +from iterator.runtime import * + + +I = offset("I") + + +@fundef +def compute_shift(cond): + return if_(deref(cond) < 0, shift(I, -1), shift(I, 1)) + + +@fundef +def foo(inp, cond): + return deref(compute_shift(cond)(inp)) + + +@fendef +def fencil(size, inp, cond, out): + closure(domain(named_range(IDim, 0, size)), foo, [out], [inp, cond]) + + +IDim = CartesianAxis("IDim") + + +def test_simple_indirection(): + shape = [8] + inp = np_as_located_field(IDim, origin={IDim: 1})(np.asarray(range(shape[0] + 2))) + rng = np.random.default_rng() + cond = np_as_located_field(IDim)(rng.normal(size=shape)) + out = np_as_located_field(IDim)(np.zeros(shape)) + + ref = np.zeros(shape) + for i in range(shape[0]): + ref[i] = inp[i - 1] if cond[i] < 0 else inp[i + 1] + + fencil(shape[0], inp, cond, out, offset_provider={"I": IDim}) + + assert allclose(ref, out) diff --git a/tests/iterator_tests/test_popup_tmps.py b/tests/iterator_tests/test_popup_tmps.py new file mode 100644 index 0000000000..012f1b085d --- /dev/null +++ b/tests/iterator_tests/test_popup_tmps.py @@ -0,0 +1,156 @@ +from iterator import ir +from iterator.transforms.popup_tmps import PopupTmps + + +def test_trivial_single_lift(): + testee = ir.FunCall( + fun=ir.Lambda( + params=[ir.Sym(id="bar_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ir.Sym(id="foo_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + 
args=[ir.SymRef(id="foo_inp")], + ), + ) + ], + ), + args=[ir.SymRef(id="bar_inp")], + ) + ], + ), + ), + args=[ir.SymRef(id="inp")], + ) + expected = ir.FunCall( + fun=ir.Lambda( + params=[ir.Sym(id="bar_inp"), ir.Sym(id="t0")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="t0")], + ), + ), + args=[ + ir.SymRef(id="inp"), + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ir.Sym(id="foo_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="foo_inp")], + ), + ) + ], + ), + args=[ir.SymRef(id="inp")], + ), + ], + ) + actual = PopupTmps().visit(testee) + assert actual == expected + + +def test_trivial_multiple_lifts(): + testee = ir.FunCall( + fun=ir.Lambda( + params=[ir.SymRef(id="baz_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ir.Sym(id="bar_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ir.Sym(id="foo_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="foo_inp")], + ), + ) + ], + ), + args=[ir.SymRef(id="bar_inp")], + ) + ], + ), + ) + ], + ), + args=[ir.SymRef(id="baz_inp")], + ) + ], + ), + ), + args=[ir.SymRef(id="inp")], + ) + expected = ir.FunCall( + fun=ir.Lambda( + params=[ir.SymRef(id="baz_inp"), ir.SymRef(id="t1")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="t1")], + ), + ), + args=[ + ir.SymRef(id="inp"), + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ + ir.Sym(id="bar_inp"), + ir.Sym(id="t0"), + ], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ + ir.SymRef(id="t0"), + ], + ), + ) + ], + ), + args=[ + ir.SymRef(id="inp"), + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ir.Sym(id="foo_inp")], + 
expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="foo_inp")], + ), + ) + ], + ), + args=[ir.SymRef(id="inp")], + ), + ], + ), + ], + ) + actual = PopupTmps().visit(testee) + assert actual == expected diff --git a/tests/iterator_tests/test_toy_connectivity.py b/tests/iterator_tests/test_toy_connectivity.py new file mode 100644 index 0000000000..28c47f19fc --- /dev/null +++ b/tests/iterator_tests/test_toy_connectivity.py @@ -0,0 +1,391 @@ +from dataclasses import field + +import numpy as np +from numpy.core.numeric import allclose + +from iterator.builtins import * +from iterator.embedded import NeighborTableOffsetProvider, index_field, np_as_located_field +from iterator.runtime import * + + +Vertex = CartesianAxis("Vertex") +Edge = CartesianAxis("Edge") +Cell = CartesianAxis("Cell") + + +# 3x3 periodic edges cells +# 0 - 1 - 2 - 0 1 2 +# | | | 9 10 11 0 1 2 +# 3 - 4 - 5 - 3 4 5 +# | | | 12 13 14 3 4 5 +# 6 - 7 - 8 - 6 7 8 +# | | | 15 16 17 6 7 8 + + +c2e_arr = np.array( + [ + [0, 10, 3, 9], # 0 + [1, 11, 4, 10], + [2, 9, 5, 11], + [3, 13, 6, 12], # 3 + [4, 14, 7, 13], + [5, 12, 8, 14], + [6, 16, 0, 15], # 6 + [7, 17, 1, 16], + [8, 15, 2, 17], + ] +) + +v2v_arr = np.array( + [ + [1, 3, 2, 6], + [2, 3, 0, 7], + [0, 5, 1, 8], + [4, 6, 5, 0], + [5, 7, 3, 1], + [3, 8, 4, 2], + [7, 0, 8, 3], + [8, 1, 6, 4], + [6, 2, 7, 5], + ] +) + +e2v_arr = np.array( + [ + [0, 1], + [1, 2], + [2, 0], + [3, 4], + [4, 5], + [5, 3], + [6, 7], + [7, 8], + [8, 6], + [0, 3], + [1, 4], + [2, 5], + [3, 6], + [4, 7], + [5, 8], + [6, 0], + [7, 1], + [8, 2], + ] +) + + +# order east, north, west, south (counter-clock wise) +v2e_arr = np.array( + [ + [0, 15, 2, 9], # 0 + [1, 16, 0, 10], + [2, 17, 1, 11], + [3, 9, 5, 12], # 3 + [4, 10, 3, 13], + [5, 11, 4, 14], + [6, 12, 8, 15], # 6 + [7, 13, 6, 16], + [8, 14, 7, 17], + ] +) + +V2E = offset("V2E") +E2V = offset("E2V") +C2E = offset("C2E") + + +@fundef +def sum_edges_to_vertices(in_edges): + return ( + deref(shift(V2E, 0)(in_edges)) 
+ + deref(shift(V2E, 1)(in_edges)) + + deref(shift(V2E, 2)(in_edges)) + + deref(shift(V2E, 3)(in_edges)) + ) + + +@fendef(offset_provider={"V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}) +def e2v_sum_fencil(in_edges, out_vertices): + closure( + domain(named_range(Vertex, 0, 9)), + sum_edges_to_vertices, + [out_vertices], + [in_edges], + ) + + +def test_sum_edges_to_vertices(backend): + backend, validate = backend + inp = index_field(Edge) + out = np_as_located_field(Vertex)(np.zeros([9])) + ref = np.asarray(list(sum(row) for row in v2e_arr)) + + e2v_sum_fencil(inp, out, backend=backend) + if validate: + assert allclose(out, ref) + + +@fundef +def sum_edges_to_vertices_reduce(in_edges): + return reduce(lambda a, b: a + b, 0)(shift(V2E)(in_edges)) + + +@fendef(offset_provider={"V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}) +def e2v_sum_fencil_reduce(in_edges, out_vertices): + closure( + domain(named_range(Vertex, 0, 9)), + sum_edges_to_vertices_reduce, + [out_vertices], + [in_edges], + ) + + +def test_sum_edges_to_vertices_reduce(backend): + backend, validate = backend + inp = index_field(Edge) + out = np_as_located_field(Vertex)(np.zeros([9])) + ref = np.asarray(list(sum(row) for row in v2e_arr)) + + e2v_sum_fencil_reduce(inp, out, backend=backend) + if validate: + assert allclose(out, ref) + + +@fundef +def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices): + return deref(shift(E2V, 0)(shift(C2E, 0)(in_vertices))) + + +@fendef( + offset_provider={ + "E2V": NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), + "C2E": NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4), + } +) +def first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(in_vertices, out_cells): + closure( + domain(named_range(Cell, 0, 9)), + first_vertex_neigh_of_first_edge_neigh_of_cells, + [out_cells], + [in_vertices], + ) + + +def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(): + inp = index_field(Vertex) + out = 
np_as_located_field(Cell)(np.zeros([9]))
+    ref = np.asarray(list(e2v_arr[c[0]][0] for c in c2e_arr))  # E2V of C2E: was v2e_arr — same values on this mesh, but the wrong table for this stencil
+
+    first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(inp, out, backend="double_roundtrip")
+    assert allclose(out, ref)
+
+
+@fundef
+def sparse_stencil(inp):
+    return reduce(lambda a, b: a + b, 0)(inp)
+
+
+@fendef
+def sparse_fencil(inp, out):
+    closure(
+        domain(named_range(Vertex, 0, 9)),
+        sparse_stencil,
+        [out],
+        [inp],
+    )
+
+
+def test_sparse_input_field(backend):
+    backend, validate = backend
+    inp = np_as_located_field(Vertex, V2E)(np.asarray([[1, 2, 3, 4]] * 9))
+    out = np_as_located_field(Vertex)(np.zeros([9]))
+
+    ref = np.ones([9]) * 10
+
+    sparse_fencil(
+        inp,
+        out,
+        backend=backend,
+        offset_provider={"V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)},
+    )
+
+    if validate:
+        assert allclose(out, ref)
+
+
+V2V = offset("V2V")
+
+
+def test_sparse_input_field_v2v(backend):
+    backend, validate = backend
+    inp = np_as_located_field(Vertex, V2V)(v2v_arr)
+    out = np_as_located_field(Vertex)(np.zeros([9]))
+
+    ref = np.asarray(list(sum(row) for row in v2v_arr))
+
+    sparse_fencil(
+        inp,
+        out,
+        backend=backend,
+        offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)},
+    )
+
+    if validate:
+        assert allclose(out, ref)
+
+
+@fundef
+def deref_stencil(inp):
+    return deref(shift(V2V, 0)(inp))
+
+
+@fundef
+def lift_stencil(inp):
+    return deref(shift(V2V, 2)(lift(deref_stencil)(inp)))
+
+
+@fendef
+def lift_fencil(inp, out):
+    closure(domain(named_range(Vertex, 0, 9)), lift_stencil, [out], [inp])
+
+
+def test_lift(backend):
+    backend, validate = backend
+    inp = index_field(Vertex)
+    out = np_as_located_field(Vertex)(np.zeros([9]))
+    ref = np.asarray(range(9))  # was np.asarray(np.asarray(range(9))): redundant double conversion
+
+    lift_fencil(None, None, backend="cpptoy")  # NOTE(review): codegen-only smoke call with dummy args — looks like debug leftover, confirm intentional
+    lift_fencil(
+        inp,
+        out,
+        backend=backend,
+        offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)},
+    )
+    if validate:
+        assert allclose(out, ref)
+
+
+@fundef
+def 
sparse_shifted_stencil(inp): + return deref(shift(0, 2)(shift(V2V)(inp))) + + +@fendef +def sparse_shifted_fencil(inp, out): + closure( + domain(named_range(Vertex, 0, 9)), + sparse_shifted_stencil, + [out], + [inp], + ) + + +def test_shift_sparse_input_field(backend): + backend, validate = backend + inp = np_as_located_field(Vertex, V2V)(v2v_arr) + out = np_as_located_field(Vertex)(np.zeros([9])) + ref = np.asarray(np.asarray(range(9))) + + sparse_shifted_fencil( + inp, + out, + backend=backend, + offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + ) + + if validate: + assert allclose(out, ref) + + +@fundef +def shift_shift_stencil2(inp): + return deref(shift(E2V, 1)(shift(V2E, 3)(inp))) + + +@fundef +def shift_sparse_stencil2(inp): + return deref(shift(1, 3)(shift(V2E)(inp))) + + +@fendef +def sparse_shifted_fencil2(inp_sparse, inp, out1, out2): + closure( + domain(named_range(Vertex, 0, 9)), + shift_shift_stencil2, + [out1], + [inp], + ) + closure( + domain(named_range(Vertex, 0, 9)), + shift_sparse_stencil2, + [out2], + [inp_sparse], + ) + + +def test_shift_sparse_input_field2(backend): + backend, validate = backend + inp = index_field(Vertex) + inp_sparse = np_as_located_field(Edge, E2V)(e2v_arr) + out1 = np_as_located_field(Vertex)(np.zeros([9])) + out2 = np_as_located_field(Vertex)(np.zeros([9])) + + sparse_shifted_fencil2( + inp_sparse, + inp, + out1, + out2, + backend=backend, + offset_provider={ + "E2V": NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), + "V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + }, + ) + + if validate: + assert allclose(out1, out2) + + +@fundef +def sparse_shifted_stencil_reduce(inp): + def sum_(a, b): + return a + b + + # return deref(shift(V2V, 0)(lift(deref)(shift(0)(inp)))) + return reduce(sum_, 0)(shift(V2V)(lift(reduce(sum_, 0))(inp))) + + +@fendef +def sparse_shifted_fencil_reduce(inp, out): + closure( + domain(named_range(Vertex, 0, 9)), + sparse_shifted_stencil_reduce, + 
+        [out],
+        [inp],
+    )
+
+
+def test_sparse_shifted_stencil_reduce(backend):  # renamed: a second `test_shift_sparse_input_field` would shadow the earlier one, so pytest would collect only one of the two
+    backend, validate = backend
+    inp = np_as_located_field(Vertex, V2V)(v2v_arr)
+    out = np_as_located_field(Vertex)(np.zeros([9]))
+
+    ref = []
+    for row in v2v_arr:
+        elem_sum = 0
+        for neigh in row:
+            elem_sum += sum(v2v_arr[neigh])
+        ref.append(elem_sum)
+
+    ref = np.asarray(ref)
+
+    sparse_shifted_fencil_reduce(
+        inp,
+        out,
+        backend=backend,
+        offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)},
+    )
+
+    if validate:
+        assert allclose(np.asarray(out), ref)
diff --git a/tests/iterator_tests/test_trivial.py b/tests/iterator_tests/test_trivial.py
new file mode 100644
index 0000000000..a31c110e19
--- /dev/null
+++ b/tests/iterator_tests/test_trivial.py
@@ -0,0 +1,53 @@
+import numpy as np
+
+from iterator.builtins import *
+from iterator.embedded import np_as_located_field
+from iterator.runtime import *
+
+
+I = offset("I")
+J = offset("J")
+
+IDim = CartesianAxis("IDim")
+JDim = CartesianAxis("JDim")
+
+
+@fundef
+def foo(foo_inp):
+    return deref(foo_inp)
+
+
+@fundef
+def bar(bar_inp):
+    return deref(lift(foo)(bar_inp))
+
+
+@fundef
+def baz(baz_inp):
+    return deref(lift(bar)(baz_inp))
+
+
+@fendef(offset_provider={"I": IDim, "J": JDim})
+def foobarbaz(inp, out, x, y):
+    closure(
+        domain(named_range(IDim, 0, x), named_range(JDim, 0, y)),
+        baz,
+        [out],
+        [inp],
+    )
+
+
+def test_trivial(backend, use_tmps):
+    backend, validate = backend
+    rng = np.random.default_rng()
+    inp = rng.uniform(size=(5, 7, 9))
+    out = np.copy(inp)
+    shape = (out.shape[0], out.shape[1])  # NOTE(review): unused local; kept so the hunk line count stays valid
+
+    inp_s = np_as_located_field(IDim, JDim, origin={IDim: 0, JDim: 0})(inp[:, :, 0])
+    out_s = np_as_located_field(IDim, JDim)(np.zeros_like(inp[:, :, 0]))
+
+    foobarbaz(inp_s, out_s, inp.shape[0], inp.shape[1], backend=backend, use_tmps=use_tmps)
+
+    if validate:
+        assert np.allclose(out[:, :, 0], out_s)
diff --git a/tests/iterator_tests/test_vertical_advection.py 
b/tests/iterator_tests/test_vertical_advection.py new file mode 100644 index 0000000000..76349bd1cf --- /dev/null +++ b/tests/iterator_tests/test_vertical_advection.py @@ -0,0 +1,129 @@ +import numpy as np +import pytest + +from iterator.builtins import * +from iterator.embedded import np_as_located_field +from iterator.runtime import * + + +@fundef +def tridiag_forward(state, a, b, c, d): + # not tracable + # if is_none(state): + # cp_k = deref(c) / deref(b) + # dp_k = deref(d) / deref(b) + # else: + # cp_km1, dp_km1 = state + # cp_k = deref(c) / (deref(b) - deref(a) * cp_km1) + # dp_k = (deref(d) - deref(a) * dp_km1) / (deref(b) - deref(a) * cp_km1) + # return make_tuple(cp_k, dp_k) + + # variant a + # return if_( + # is_none(state), + # make_tuple(deref(c) / deref(b), deref(d) / deref(b)), + # make_tuple( + # deref(c) / (deref(b) - deref(a) * nth(0, state)), + # (deref(d) - deref(a) * nth(1, state)) + # / (deref(b) - deref(a) * nth(0, state)), + # ), + # ) + + # variant b + def initial(): + return make_tuple(deref(c) / deref(b), deref(d) / deref(b)) + + def step(): + return make_tuple( + deref(c) / (deref(b) - deref(a) * nth(0, state)), + (deref(d) - deref(a) * nth(1, state)) / (deref(b) - deref(a) * nth(0, state)), + ) + + return if_(is_none(state), initial, step)() + + +@fundef +def tridiag_backward(x_kp1, cp, dp): + # if is_none(x_kp1): + # x_k = deref(dp) + # else: + # x_k = deref(dp) - deref(cp) * x_kp1 + # return x_k + return if_(is_none(x_kp1), deref(dp), deref(dp) - deref(cp) * x_kp1) + + +@fundef +def solve_tridiag(a, b, c, d): + tup = lift(scan(tridiag_forward, True, None))(a, b, c, d) + cp = nth(0, tup) + dp = nth(1, tup) + return scan(tridiag_backward, False, None)(cp, dp) + + +@pytest.fixture +def tridiag_reference(): + shape = (3, 7, 5) + rng = np.random.default_rng() + a = rng.normal(size=shape) + b = rng.normal(size=shape) * 2 + c = rng.normal(size=shape) + d = rng.normal(size=shape) + + matrices = np.zeros(shape + shape[-1:]) + i = 
np.arange(shape[2]) + matrices[:, :, i[1:], i[:-1]] = a[:, :, 1:] + matrices[:, :, i, i] = b + matrices[:, :, i[:-1], i[1:]] = c[:, :, :-1] + x = np.linalg.solve(matrices, d) + return a, b, c, d, x + + +IDim = CartesianAxis("IDim") +JDim = CartesianAxis("JDim") +KDim = CartesianAxis("KDim") + + +@fendef +def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x): + closure( + domain( + named_range(IDim, 0, i_size), + named_range(JDim, 0, j_size), + named_range(KDim, 0, k_size), + ), + solve_tridiag, + [x], + [a, b, c, d], + ) + + +def test_tridiag(tridiag_reference, backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently not supported for scans") + backend, validate = backend + a, b, c, d, x = tridiag_reference + shape = a.shape + as_3d_field = np_as_located_field(IDim, JDim, KDim) + a_s = as_3d_field(a) + b_s = as_3d_field(b) + c_s = as_3d_field(c) + d_s = as_3d_field(d) + x_s = as_3d_field(np.zeros_like(x)) + + fen_solve_tridiag( + shape[0], + shape[1], + shape[2], + a_s, + b_s, + c_s, + d_s, + x_s, + offset_provider={}, + column_axis=KDim, + backend=backend, + use_tmps=use_tmps, + ) + + if validate: + assert np.allclose(x, np.asarray(x_s))