From e8165ec0e172b617af8aff8f844a9e6fc008d623 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 4 Aug 2021 15:33:52 +0200 Subject: [PATCH 1/6] Move prototype from gt4py_new_model repository https://github.com/fthaler/gt4py_new_model/tree/iterator_v2 --- src/iterator/ARCHITECTURE.md | 60 ++ src/iterator/README.md | 1 + src/iterator/__init__.py | 20 + src/iterator/atlas_utils.py | 33 + src/iterator/backend_executor.py | 20 + src/iterator/backends/__init__.py | 1 + src/iterator/backends/backend.py | 9 + src/iterator/backends/cpptoy.py | 49 ++ src/iterator/backends/double_roundtrip.py | 9 + src/iterator/backends/embedded.py | 103 ++++ src/iterator/backends/lisp.py | 60 ++ src/iterator/builtins.py | 113 ++++ src/iterator/dispatcher.py | 56 ++ src/iterator/embedded.py | 570 ++++++++++++++++++ src/iterator/ir.py | 116 ++++ src/iterator/library.py | 15 + src/iterator/runtime.py | 84 +++ src/iterator/tracing.py | 303 ++++++++++ src/iterator/util/__init__.py | 0 src/iterator/util/sym_validation.py | 51 ++ src/iterator/utils.py | 19 + tests/iterator_tests/__init__.py | 0 tests/iterator_tests/fvm_nabla_setup.py | 204 +++++++ tests/iterator_tests/hdiff_reference.py | 41 ++ tests/iterator_tests/test_anton_toy.py | 83 +++ .../test_cartesian_offset_provider.py | 79 +++ tests/iterator_tests/test_column_stencil.py | 126 ++++ tests/iterator_tests/test_fvm_nabla.py | 285 +++++++++ tests/iterator_tests/test_hdiff.py | 70 +++ .../test_horizontal_indirection.py | 53 ++ tests/iterator_tests/test_toy_connectivity.py | 334 ++++++++++ .../iterator_tests/test_vertical_advection.py | 124 ++++ 32 files changed, 3091 insertions(+) create mode 100644 src/iterator/ARCHITECTURE.md create mode 100644 src/iterator/README.md create mode 100644 src/iterator/__init__.py create mode 100644 src/iterator/atlas_utils.py create mode 100644 src/iterator/backend_executor.py create mode 100644 src/iterator/backends/__init__.py create mode 100644 src/iterator/backends/backend.py create mode 100644 src/iterator/backends/cpptoy.py create mode 100644 src/iterator/backends/double_roundtrip.py create mode 100644 src/iterator/backends/embedded.py create mode 100644 src/iterator/backends/lisp.py create mode 100644 src/iterator/builtins.py create mode 100644 src/iterator/dispatcher.py create mode 100644 src/iterator/embedded.py create mode 100644 src/iterator/ir.py create mode 100644 src/iterator/library.py create mode 100644 src/iterator/runtime.py create mode 100644 src/iterator/tracing.py create mode 100644 src/iterator/util/__init__.py create mode 100644 src/iterator/util/sym_validation.py create mode 100644 src/iterator/utils.py create mode 100644 tests/iterator_tests/__init__.py create mode 100644 tests/iterator_tests/fvm_nabla_setup.py create mode 100644 tests/iterator_tests/hdiff_reference.py create mode 100644 tests/iterator_tests/test_anton_toy.py create mode 100644 tests/iterator_tests/test_cartesian_offset_provider.py create mode 100644 tests/iterator_tests/test_column_stencil.py create mode 100644 tests/iterator_tests/test_fvm_nabla.py create mode 100644 tests/iterator_tests/test_hdiff.py create mode 100644 tests/iterator_tests/test_horizontal_indirection.py create mode 100644 tests/iterator_tests/test_toy_connectivity.py create mode 100644 tests/iterator_tests/test_vertical_advection.py diff --git a/src/iterator/ARCHITECTURE.md b/src/iterator/ARCHITECTURE.md new file mode 100644 index 0000000000..decd642db9 --- /dev/null +++ b/src/iterator/ARCHITECTURE.md @@ -0,0 +1,60 @@ +# Architecture + +Implements the iterator view as 
described [here](https://github.com/GridTools/concepts/wiki/Iterator-View).
+
+## Iterator view program in Python
+
+A program for the iterator view consists of Python functions decorated with `@fundef` and an entry point, a *fencil*, which is a Python function decorated with `@fendef`. The *fencil* must only contain calls to the `closure(...)` function.
+
+Legal functions must not have side-effects; however, side-effects can be used in embedded execution, e.g. for debugging purposes.
+
+There are two modes of execution: *embedded* (direct execution in Python) and *tracing* (trace function calls -> Eve IR representation -> code generation).
+The implementations of *embedded* and *tracing* are decoupled by registering themselves (dependency inversion) in the functions defined in `builtins.py` (contains dispatch functions for all builtins of the model) and `runtime.py` (contains the dispatch mechanism for the `fendef` and `fundef` decorators and the `closure(...)` function).
+
+The builtins dispatcher is implemented in `dispatcher.py`. Implementations are registered with a key (`str`), currently `tracing` and `embedded`. The active implementation is selected by pushing a key onto the dispatcher stack.
+
+`fundef` returns a wrapper around the function, which dispatches `__call__` to a hook if a predicate is met (used for *tracing*). By default the original function is called (used in *embedded* mode).
+
+`fendef` returns a wrapper that dispatches to a registered function; this should be simplified. *Embedded* registers itself as the default; *tracing* registers itself such that it is used if the fencil is called with the `backend` keyword argument.
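+
+For illustration, a minimal program in this style could look as follows. This is only a sketch distilled from the tests added in this patch (e.g. `test_cartesian_offset_provider.py`); the axis `IDim`, the offset `I`, and the names `shift_stencil`/`shift_fencil` are made up for the example and are not part of the library.
+
+```python
+import numpy as np
+
+from iterator.builtins import *
+from iterator.runtime import *
+from iterator.embedded import np_as_located_field
+
+IDim = CartesianAxis("IDim")  # named field axis (illustrative name)
+I = offset("I")               # abstract offset, mapped to an axis via offset_provider
+
+
+@fundef
+def shift_stencil(inp):
+    return deref(shift(I, 1)(inp))  # read the neighbor at offset +1 along I
+
+
+@fendef
+def shift_fencil(size, out, inp):
+    closure(
+        domain(named_range(IDim, 0, size)),
+        shift_stencil,
+        [out],
+        [inp],
+    )
+
+
+inp = np_as_located_field(IDim)(np.arange(6.0))
+out = np_as_located_field(IDim)(np.zeros(5))
+
+# embedded execution (the default mode); "I" is resolved to the Cartesian axis IDim
+shift_fencil(5, out, inp, offset_provider={"I": IDim})
+assert np.allclose(np.asarray(out), np.arange(1.0, 6.0))
+
+# passing a `backend` keyword instead triggers tracing and code generation,
+# here printing the C++ generated by the toy backend
+shift_fencil(5, out, inp, backend="cpptoy")
+```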
+
+## Embedded execution
+
+Embedded execution is implemented in the file `embedded.py`.
+
+Sketch:
+- fields are NumPy arrays with named axes; names are instances of `CartesianAxis`
+- in `closure()`, the stencil is executed for each point in the domain; the fields are wrapped in an iterator pointing to the current point of execution.
+- `shift()` is lazy (to allow the `lift` implementation): offsets are accumulated in the iterator and only applied when `deref()` is called.
+- as described in the design, offsets are abstract; on fencil execution the `offset_provider` keyword argument needs to be specified, which is a dict mapping `str` to either a `CartesianAxis` or a `NeighborTableOffsetProvider`
+- if the `column_axis` keyword argument is specified on fencil execution (or in the fencil decorator), all operations are performed column-wise along the given axis; `column_axis` needs to be specified if `scan` is used
+
+## Tracing
+
+An iterator view program is traced (implemented in `tracing.py`) and represented in a tree structure defined by the nodes in `ir.py`.
+
+Sketch:
+- Each builtin returns a `FunCall` node representing the builtin.
+- For each `fundef`, the signature of the wrapped function is extracted, then the function is invoked with `SymRef` nodes as arguments.
+- Expressions involving an `Expr` node (e.g. `SymRef`) are converted to the appropriate builtin calls, e.g. `4. + SymRef(id='foo')` is converted to `FunCall(fun=SymRef(id='plus'), args=...)`.
+- In appropriate places values are converted to nodes, see `make_node()`.
+- Finally, the IR tree is passed to `execute_program()` in `backend_executor.py`, which generates code for the program (and executes it, if appropriate).
+
+## Backends
+
+See directory `backends/`.
+
+### Cpptoy
+
+Generates C++ code in the spirit of https://github.com/GridTools/gridtools/pull/1643 (only code generation, no execution). Incomplete; will be adapted to the full C++ prototype.
+
+### Lisp
+
+Incomplete; not executable. An example of the grammar used in the model design document.
+
+### Embedded
+
+Generates an equivalent Python iterator view program from the IR, which is then executed in embedded mode (round trip).
+
+### Double roundtrip
+
+Generates the Python iterator view program, traces it again, generates again, and executes. This ensures that the generated Python code can still be traced. While the original program might differ from the generated program (e.g. `+` is converted to the `plus()` builtin), the programs produced by the embedded and double roundtrip backends should be identical.
diff --git a/src/iterator/README.md b/src/iterator/README.md
new file mode 100644
index 0000000000..7e59600739
--- /dev/null
+++ b/src/iterator/README.md
@@ -0,0 +1 @@
+# README
diff --git a/src/iterator/__init__.py b/src/iterator/__init__.py
new file mode 100644
index 0000000000..3d53598001
--- /dev/null
+++ b/src/iterator/__init__.py
@@ -0,0 +1,20 @@
+from typing import Optional, Union
+from . import tracing
+
+from .builtins import *
+from .runtime import *
+
+__all__ = ["builtins", "runtime"]
+
+from packaging.version import LegacyVersion, Version, parse
+from pkg_resources import DistributionNotFound, get_distribution
+
+
+try:
+    __version__: str = get_distribution("gt4py").version
+except DistributionNotFound:
+    __version__ = "X.X.X.unknown"
+
+__versioninfo__: Optional[Union[LegacyVersion, Version]] = parse(__version__)
+
+del DistributionNotFound, LegacyVersion, Version, get_distribution, parse
diff --git a/src/iterator/atlas_utils.py b/src/iterator/atlas_utils.py
new file mode 100644
index 0000000000..dc6907186a
--- /dev/null
+++ b/src/iterator/atlas_utils.py
@@ -0,0 +1,33 @@
+# GT4Py New Semantic Model - GridTools Framework
+#
+# Copyright (c) 2014-2021, ETH Zurich All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework. GT4Py
+# New Semantic Model is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by the Free
+# Software Foundation, either version 3 of the License, or any later version.
+# See the LICENSE.txt file at the top-level directory of this distribution for
+# a copy of the license or check <https://www.gnu.org/licenses/>.
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +from atlas4py import IrregularConnectivity + + +class AtlasTable: + def __init__(self, atlas_connectivity) -> None: + self.atlas_connectivity = atlas_connectivity + + def __getitem__(self, indices): + primary_index = indices[0] + neigh_index = indices[1] + if isinstance(self.atlas_connectivity, IrregularConnectivity): + if neigh_index < self.atlas_connectivity.cols(primary_index): + return self.atlas_connectivity[primary_index, neigh_index] + else: + return None + else: + if neigh_index < 2: + return self.atlas_connectivity[primary_index, neigh_index] + else: + assert False diff --git a/src/iterator/backend_executor.py b/src/iterator/backend_executor.py new file mode 100644 index 0000000000..2231ef3d35 --- /dev/null +++ b/src/iterator/backend_executor.py @@ -0,0 +1,20 @@ +from iterator.ir import Program +from iterator.backends import backend +from devtools import debug + + +def execute_program(prog: Program, *args, **kwargs): + assert "backend" in kwargs + assert len(prog.fencil_definitions) == 1 + + if "debug" in kwargs and kwargs["debug"]: + debug(prog) + + if not len(args) == len(prog.fencil_definitions[0].params): + raise RuntimeError("Incorrect number of arguments") + + if kwargs["backend"] in backend._BACKENDS: + b = backend.get_backend(kwargs["backend"]) + b(prog, *args, **kwargs) + else: + raise RuntimeError(f"Backend {kwargs['backend']} is not registered.") diff --git a/src/iterator/backends/__init__.py b/src/iterator/backends/__init__.py new file mode 100644 index 0000000000..ac69509c36 --- /dev/null +++ b/src/iterator/backends/__init__.py @@ -0,0 +1 @@ +from . import cpptoy, lisp, embedded, double_roundtrip diff --git a/src/iterator/backends/backend.py b/src/iterator/backends/backend.py new file mode 100644 index 0000000000..5c6fdf3f71 --- /dev/null +++ b/src/iterator/backends/backend.py @@ -0,0 +1,9 @@ +_BACKENDS = {} + + +def register_backend(name, backend): + _BACKENDS[name] = backend + + +def get_backend(name): + return _BACKENDS[name] diff --git a/src/iterator/backends/cpptoy.py b/src/iterator/backends/cpptoy.py new file mode 100644 index 0000000000..8f81eb0f48 --- /dev/null +++ b/src/iterator/backends/cpptoy.py @@ -0,0 +1,49 @@ +from typing import Any +from eve import codegen +from eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako +from iterator.ir import OffsetLiteral +from iterator.backends import backend + + +class ToyCpp(codegen.TemplatedGenerator): + Sym = as_fmt("{id}") + SymRef = as_fmt("{id}") + IntLiteral = as_fmt("{value}") + FloatLiteral = as_fmt("{value}") + AxisLiteral = as_fmt("{value}") + + def visit_OffsetLiteral(self, node: OffsetLiteral, **kwargs): + return node.value if isinstance(node.value, str) else f"{node.value}_c" + + StringLiteral = as_fmt("{value}") + FunCall = as_fmt("{fun}({','.join(args)})") + Lambda = as_mako( + "[=](${','.join('auto ' + p for p in params)}){return ${expr};}" + ) # TODO capture + StencilClosure = as_mako( + "closure(${domain}, ${stencil}, out(${','.join(outputs)}), ${','.join(inputs)})" + ) + FencilDefinition = as_mako( + """ + auto ${id} = [](${','.join('auto&& ' + p for p in params)}){ + fencil(${'\\n'.join(closures)}); + }; + """ + ) + FunctionDefinition = as_mako( + """ + inline constexpr auto ${id} = [](${','.join('auto ' + p for p in params)}){ + return ${expr}; + }; + """ + ) + Program = as_fmt("{''.join(function_definitions)} {''.join(fencil_definitions)}") + + @classmethod + def apply(cls, root, **kwargs: Any) -> str: + generated_code = 
super().apply(root, **kwargs) + formatted_code = codegen.format_source("cpp", generated_code, style="LLVM") + return formatted_code + + +backend.register_backend("cpptoy", lambda prog, *args, **kwargs: print(ToyCpp.apply(prog))) diff --git a/src/iterator/backends/double_roundtrip.py b/src/iterator/backends/double_roundtrip.py new file mode 100644 index 0000000000..ea1d57c263 --- /dev/null +++ b/src/iterator/backends/double_roundtrip.py @@ -0,0 +1,9 @@ +from eve.concepts import Node +from iterator.backends import backend, embedded + + +def executor(ir: Node, *args, **kwargs): + embedded.executor(ir, *args, dispatch_backend=embedded._BACKEND_NAME, **kwargs) + + +backend.register_backend("double_roundtrip", executor) diff --git a/src/iterator/backends/embedded.py b/src/iterator/backends/embedded.py new file mode 100644 index 0000000000..a495f556c8 --- /dev/null +++ b/src/iterator/backends/embedded.py @@ -0,0 +1,103 @@ +from eve import codegen +from eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako +from eve.concepts import Node +from iterator.ir import AxisLiteral, NoneLiteral, OffsetLiteral +from iterator.backends import backend +import tempfile +import importlib.util +import iterator + + +class EmbeddedDSL(codegen.TemplatedGenerator): + Sym = as_fmt("{id}") + SymRef = as_fmt("{id}") + BoolLiteral = as_fmt("{value}") + IntLiteral = as_fmt("{value}") + FloatLiteral = as_fmt("{value}") + NoneLiteral = as_fmt("None") + OffsetLiteral = as_fmt("{value}") + AxisLiteral = as_fmt("{value}") + StringLiteral = as_fmt("{value}") + FunCall = as_fmt("{fun}({','.join(args)})") + Lambda = as_mako("lambda ${','.join(params)}: ${expr}") + StencilClosure = as_mako( + "closure(${domain}, ${stencil}, [${','.join(outputs)}], [${','.join(inputs)}])" + ) + FencilDefinition = as_mako( + """ +@fendef +def ${id}(${','.join(params)}): + ${'\\n '.join(closures)} + """ + ) + FunctionDefinition = as_mako( + """ +@fundef +def ${id}(${','.join(params)}): + return ${expr} + """ + ) + Program = as_fmt( + """ +{''.join(function_definitions)} {''.join(fencil_definitions)}""" + ) + + +_BACKEND_NAME = "embedded" + + +def executor(ir: Node, *args, **kwargs): + debug = "debug" in kwargs and kwargs["debug"] == True + + program = EmbeddedDSL.apply(ir) + offset_literals = ( + ir.iter_tree().if_isinstance(OffsetLiteral).getattr("value").if_isinstance(str).to_set() + ) + axis_literals = ir.iter_tree().if_isinstance(AxisLiteral).getattr("value").to_set() + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".py", + delete=not debug, + ) as tmp: + if debug: + print(tmp.name) + header = """ +from iterator.builtins import * +from iterator.runtime import * +""" + offset_literals = [f'{l} = offset("{l}")' for l in offset_literals] + axis_literals = [f'{l} = CartesianAxis("{l}")' for l in axis_literals] + tmp.write(header) + tmp.write("\n".join(offset_literals)) + tmp.write("\n") + tmp.write("\n".join(axis_literals)) + tmp.write("\n") + tmp.write(program) + tmp.flush() + + spec = importlib.util.spec_from_file_location("module.name", tmp.name) + foo = importlib.util.module_from_spec(spec) + spec.loader.exec_module(foo) + + fencil_name = ir.fencil_definitions[0].id + fencil = getattr(foo, fencil_name) + assert "offset_provider" in kwargs + + new_kwargs = {} + new_kwargs["offset_provider"] = kwargs["offset_provider"] + if "column_axis" in kwargs: + new_kwargs["column_axis"] = kwargs["column_axis"] + + if not "dispatch_backend" in kwargs: + iterator.builtins.builtin_dispatch.push_key("embedded") + fencil(*args, 
**new_kwargs) + iterator.builtins.builtin_dispatch.pop_key() + else: + fencil( + *args, + **new_kwargs, + backend=kwargs["dispatch_backend"], + ) + + +backend.register_backend(_BACKEND_NAME, executor) diff --git a/src/iterator/backends/lisp.py b/src/iterator/backends/lisp.py new file mode 100644 index 0000000000..c4eed300dd --- /dev/null +++ b/src/iterator/backends/lisp.py @@ -0,0 +1,60 @@ +from typing import Any +from eve.codegen import TemplatedGenerator +from eve.codegen import FormatTemplate as as_fmt + +# from yasi import indent_code + +from iterator.backends import backend + + +class ToLispLike(TemplatedGenerator): + Sym = as_fmt("{id}") + FunCall = as_fmt("({fun} {' '.join(args)})") + IntLiteral = as_fmt("{value}") + OffsetLiteral = as_fmt("{value}") + StringLiteral = as_fmt("{value}") + SymRef = as_fmt("{id}") + Program = as_fmt( + """ + {''.join(function_definitions)} + {''.join(fencil_definitions)} + {''.join(setqs)} + """ + ) + StencilClosure = as_fmt( + """( + :domain {domain} + :stencil {stencil} + :outputs {' '.join(outputs)} + :inputs {' '.join(inputs)} + ) + """ + ) + FencilDefinition = as_fmt( + """(defen {id}({' '.join(params)}) + {''.join(closures)}) + """ + ) + FunctionDefinition = as_fmt( + """(defun {id}({' '.join(params)}) + {expr} + ) + +""" + ) + Lambda = as_fmt( + """(lambda ({' '.join(params)}) + {expr} + )""" + ) + + @classmethod + def apply(cls, root, **kwargs: Any) -> str: + generated_code = super().apply(root, **kwargs) + return generated_code + # indented = indent_code(generated_code, "--dialect lisp") + # formatted_code = "".join(indented["indented_code"]) + # return formatted_code + + +backend.register_backend("lisp", lambda prog, *args, **kwargs: print(ToLispLike.apply(prog))) diff --git a/src/iterator/builtins.py b/src/iterator/builtins.py new file mode 100644 index 0000000000..5e2ddecef3 --- /dev/null +++ b/src/iterator/builtins.py @@ -0,0 +1,113 @@ +from iterator.dispatcher import Dispatcher + +__all__ = [ + "deref", + "shift", + "lift", + "reduce", + "scan", + "is_none", + "domain", + "named_range", + "compose", + "if_", + "minus", + "plus", + "mul", + "div", + "greater", + "make_tuple", + "nth", +] + +builtin_dispatch = Dispatcher() + + +class BackendNotSelectedError(RuntimeError): + def __init__(self) -> None: + super().__init__("Backend not selected") + + +@builtin_dispatch +def deref(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def shift(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def lift(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def reduce(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def scan(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def is_none(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def domain(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def named_range(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def compose(sten): + raise BackendNotSelectedError() + + +@builtin_dispatch +def if_(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def minus(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def plus(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def mul(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def div(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def greater(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def make_tuple(*args): + raise 
BackendNotSelectedError + + +@builtin_dispatch +def nth(*args): + raise BackendNotSelectedError diff --git a/src/iterator/dispatcher.py b/src/iterator/dispatcher.py new file mode 100644 index 0000000000..15e9a93865 --- /dev/null +++ b/src/iterator/dispatcher.py @@ -0,0 +1,56 @@ +from typing import Any + +# TODO test + + +class _fun_dispatcher: + def __init__(self, dispatcher, fun) -> None: + self.dispatcher = dispatcher + self.fun = fun + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if self.dispatcher.key is None: + return self.fun(*args, **kwargs) + else: + return self.dispatcher._funs[self.dispatcher.key][self.fun.__name__]( + *args, **kwargs + ) + + def register(self, key): + self.dispatcher.register_key(key) + + def _impl(fun): + self.dispatcher._funs[key][self.fun.__name__] = fun + + return _impl + + +class Dispatcher: + def __init__(self) -> None: + self._funs = {} + self.key_stack = [] + + @property + def key(self): + return self.key_stack[-1] if self.key_stack else None + + def register_key(self, key): + if not key in self._funs: + self._funs[key] = {} + + def push_key(self, key): + if key not in self._funs: + raise RuntimeError(f"Key {key} not registered") + self.key_stack.append(key) + + def pop_key(self): + self.key_stack.pop() + + def clear_key(self): + self.key_stack = [] + + def __call__(self, fun): + return self.dispatch(fun) + + def dispatch(self, fun): + return _fun_dispatcher(self, fun) diff --git a/src/iterator/embedded.py b/src/iterator/embedded.py new file mode 100644 index 0000000000..2344f5f5ba --- /dev/null +++ b/src/iterator/embedded.py @@ -0,0 +1,570 @@ +from dataclasses import dataclass +import itertools + +import iterator +from iterator.builtins import ( + builtin_dispatch, + is_none, + lift, + reduce, + shift, + deref, + scan, + domain, + named_range, + if_, + minus, + plus, + mul, + div, + greater, + nth, + make_tuple, +) +from iterator.runtime import CartesianAxis, Offset +from iterator.utils import tupelize +import numpy as np +import numbers + +EMBEDDED = "embedded" + + +class NeighborTableOffsetProvider: + def __init__(self, tbl, origin_axis, neighbor_axis, max_neighbors) -> None: + self.tbl = tbl + self.origin_axis = origin_axis + self.neighbor_axis = neighbor_axis + self.max_neighbors = max_neighbors + + +@deref.register(EMBEDDED) +def deref(iter): + return iter.deref() + + +@if_.register(EMBEDDED) +def if_(cond, t, f): + return t if cond else f + + +@nth.register(EMBEDDED) +def nth(i, tup): + return tup[i] + + +@make_tuple.register(EMBEDDED) +def make_tuple(*args): + return (*args,) + + +@lift.register(EMBEDDED) +def lift(stencil): + def impl(*args): + class wrap_iterator: + def __init__(self, *, offsets=[], elem=None) -> None: + self.offsets = offsets + self.elem = elem + + # TODO needs to be supported by all iterators that represent tuples + def __getitem__(self, index): + return wrap_iterator(offsets=self.offsets, elem=index) + + def shift(self, *offsets): + return wrap_iterator(offsets=[*offsets, *self.offsets], elem=self.elem) + + def max_neighbors(self): + # TODO cleanup, test edge cases + open_offsets = get_open_offsets(*self.offsets) + assert open_offsets + assert isinstance( + args[0].offset_provider[open_offsets[0].value], + NeighborTableOffsetProvider, + ) + return args[0].offset_provider[open_offsets[0].value].max_neighbors + + def deref(self): + class DelayedIterator: + def __init__(self, wrapped_iterator, lifted_offsets, *, offsets=[]) -> None: + self.wrapped_iterator = wrapped_iterator + self.lifted_offsets = lifted_offsets + 
self.offsets = offsets + + def is_none(self): + shifted = self.wrapped_iterator.shift(*self.lifted_offsets, *self.offsets) + return shifted.is_none() + + def max_neighbors(self): + shifted = self.wrapped_iterator.shift(*self.lifted_offsets, *self.offsets) + return shifted.max_neighbors() + + def shift(self, *offsets): + return DelayedIterator( + self.wrapped_iterator, + self.lifted_offsets, + offsets=[*offsets, *self.offsets], + ) + + def deref(self): + shifted = self.wrapped_iterator.shift(*self.lifted_offsets, *self.offsets) + return shifted.deref() + + shifted_args = tuple(map(lambda arg: DelayedIterator(arg, self.offsets), args)) + + if any(shifted_arg.is_none() for shifted_arg in shifted_args): + return None + + if self.elem is None: + return stencil(*shifted_args) + else: + return stencil(*shifted_args)[self.elem] + + return wrap_iterator() + + return impl + + +@reduce.register(EMBEDDED) +def reduce(fun, init): + def sten(*iters): + # assert check_that_all_iterators_are_compatible(*iters) + first_it = iters[0] + n = first_it.max_neighbors() + res = init + for i in range(n): + # we can check a single argument + # because all arguments share the same pattern + if iterator.builtins.deref(iterator.builtins.shift(i)(first_it)) is None: + break + res = fun( + res, + *(iterator.builtins.deref(iterator.builtins.shift(i)(it)) for it in iters), + ) + return res + + return sten + + +class _None: + """Dummy object to allow execution of expression containing Nones in non-active path + + E.g. + `if_(is_none(state), 42, 42+state)` + here 42+state needs to be evaluatable even if is_none(state) + + TODO: all possible arithmetic operations + """ + + def __add__(self, other): + return _None() + + def __radd__(self, other): + return _None() + + def __sub__(self, other): + return _None() + + def __rsub__(self, other): + return _None() + + def __mul__(self, other): + return _None() + + def __rmul__(self, other): + return _None() + + def __truediv__(self, other): + return _None() + + def __rtruediv__(self, other): + return _None() + + def __getitem__(self, i): + return _None() + + +@is_none.register(EMBEDDED) +def is_none(arg): + return isinstance(arg, _None) + + +@domain.register(EMBEDDED) +def domain(*args): + domain = {} + for arg in args: + domain.update(arg) + return domain + + +@named_range.register(EMBEDDED) +def named_range(tag, start, end): + return {tag: range(start, end)} + + +@minus.register(EMBEDDED) +def minus(first, second): + return first - second + + +@plus.register(EMBEDDED) +def plus(first, second): + return first + second + + +@mul.register(EMBEDDED) +def mul(first, second): + return first * second + + +@div.register(EMBEDDED) +def div(first, second): + return first / second + + +@greater.register(EMBEDDED) +def greater(first, second): + return first > second + + +def named_range(axis, range): + return ((axis, i) for i in range) + + +def domain_iterator(domain): + return ( + dict(elem) + for elem in itertools.product(*map(lambda tup: named_range(tup[0], tup[1]), domain.items())) + ) + + +def execute_shift(pos, tag, index, *, offset_provider): + if tag in pos and pos[tag] is None: # sparse field with offset as neighbor dimension + new_pos = pos.copy() + new_pos[tag] = index + return new_pos + assert tag.value in offset_provider + offset_implementation = offset_provider[tag.value] + if isinstance(offset_implementation, CartesianAxis): + assert offset_implementation in pos + new_pos = pos.copy() + new_pos[offset_implementation] += index + return new_pos + elif 
isinstance(offset_implementation, NeighborTableOffsetProvider): + assert offset_implementation.origin_axis in pos + new_pos = pos.copy() + del new_pos[offset_implementation.origin_axis] + if offset_implementation.tbl[pos[offset_implementation.origin_axis], index] is None: + return None + else: + new_pos[offset_implementation.neighbor_axis] = offset_implementation.tbl[ + pos[offset_implementation.origin_axis], index + ] + return new_pos + + assert False + + +# The following holds for shifts: +# shift(tag, index)(inp) -> full shift +# shift(tag)(inp) -> incomplete shift +# shift(index)(shift(tag)(inp)) -> full shift +# Therefore the following transformation holds +# shift(e2c,0)(shift(v2c,2)(cell_field)) +# = shift(0)(shift(e2c)(shift(2)(shift(v2c)(cell_field)))) +# = shift(v2c, 2, e2c, 0)(cell_field) +# = shift(v2c,e2c,2,0)(cell_field) <-- v2c,e2c twice incomplete shift +# = shift(2,0)(shift(v2c,e2c)(cell_field)) +# for implementations it means everytime we have an index, we can "execute" a concrete shift +def group_offsets(*offsets): + tag_stack = [] + index_stack = [] + complete_offsets = [] + for offset in offsets: + if not isinstance(offset, int): + if index_stack: + index = index_stack.pop(0) + complete_offsets.append((offset, index)) + else: + tag_stack.append(offset) + else: + assert not tag_stack + index_stack.append(offset) + return complete_offsets, tag_stack + + +def shift_position(pos, *offsets, offset_provider): + complete_offsets, open_offsets = group_offsets(*offsets) + # assert not open_offsets # TODO enable this, check failing test and make everything saver + + new_pos = pos.copy() + for tag, index in complete_offsets: + new_pos = execute_shift(new_pos, tag, index, offset_provider=offset_provider) + if new_pos is None: + return None + return new_pos + + +def get_open_offsets(*offsets): + return group_offsets(*offsets)[1] + + +class MDIterator: + def __init__(self, field, pos, *, offsets=[], offset_provider, column_axis=None) -> None: + self.field = field + self.pos = pos + self.offsets = offsets + self.offset_provider = offset_provider + self.column_axis = column_axis + + def shift(self, *offsets): + return MDIterator( + self.field, + self.pos, + offsets=[*offsets, *self.offsets], + offset_provider=self.offset_provider, + column_axis=self.column_axis, + ) + + def max_neighbors(self): + open_offsets = get_open_offsets(*self.offsets) + assert open_offsets + assert isinstance(self.offset_provider[open_offsets[0].value], NeighborTableOffsetProvider) + return self.offset_provider[open_offsets[0].value].max_neighbors + + def is_none(self): + return shift_position(self.pos, *self.offsets, offset_provider=self.offset_provider) is None + + def deref(self): + shifted_pos = shift_position(self.pos, *self.offsets, offset_provider=self.offset_provider) + + if not all(axis in shifted_pos.keys() for axis in self.field.axises): + raise IndexError("Iterator position doesn't point to valid location for its field.") + slice_column = {} + if self.column_axis is not None: + slice_column[self.column_axis] = slice(shifted_pos[self.column_axis], None) + del shifted_pos[self.column_axis] + ordered_indices = get_ordered_indices( + self.field.axises, + shifted_pos, + slice_axises=slice_column, + ) + return self.field[ordered_indices] + + +def make_in_iterator(inp, pos, offset_provider, *, column_axis): + sparse_dimensions = [axis for axis in inp.axises if isinstance(axis, Offset)] + assert len(sparse_dimensions) <= 1 # TODO multiple is not a current use case + new_pos = pos.copy() + for axis in 
sparse_dimensions: + new_pos[axis] = None + if column_axis is not None: + # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted + new_pos[column_axis] = 0 + return MDIterator( + inp, + new_pos, + offsets=[*sparse_dimensions], + offset_provider=offset_provider, + column_axis=column_axis, + ) + + +builtin_dispatch.push_key(EMBEDDED) # makes embedded the default + + +class LocatedField: + """A Field with named dimensions/axises. + + Axis keys can be any objects that are hashable. + """ + + def __init__(self, getter, axises, *, setter=None, array=None): + self.getter = getter + self.axises = axises + self.setter = setter + self.array = array + + def __getitem__(self, indices): + indices = tupelize(indices) + return self.getter(indices) + + def __setitem__(self, indices, value): + if self.setter is None: + raise TypeError("__setitem__ not supported for this field") + self.setter(indices, value) + + def __array__(self): + if self.array is None: + raise TypeError("__array__ not supported for this field") + return self.array() + + @property + def shape(self): + if self.array is None: + raise TypeError("`shape` not supported for this field") + return self.array().shape + + +def get_ordered_indices(axises, pos, *, slice_axises={}): + """pos is a dictionary from axis to offset""" + assert all(axis in [*pos.keys(), *slice_axises] for axis in axises) + return tuple(pos[axis] if axis in pos else slice_axises[axis] for axis in axises) + + +def _tupsum(a, b): + def combine_slice(s, t): + is_slice = False + if isinstance(s, slice): + is_slice = True + first = s.start + assert s.step is None + assert s.stop is None + else: + assert isinstance(s, numbers.Integral) + first = s + if isinstance(t, slice): + is_slice = True + second = t.start + assert t.step is None + assert t.stop is None + else: + assert isinstance(t, numbers.Integral) + second = t + start = first + second + return slice(start, None) if is_slice else start + + return tuple(combine_slice(*i) for i in zip(a, b)) + + +def np_as_located_field(*axises, origin=None): + def _maker(a: np.ndarray): + if a.ndim != len(axises): + raise TypeError("ndarray.ndim incompatible with number of given axises") + + if origin is not None: + offsets = get_ordered_indices(axises, origin) + else: + offsets = tuple(0 for _ in axises) + + def setter(indices, value): + a[_tupsum(indices, offsets)] = value + + def getter(indices): + return a[_tupsum(indices, offsets)] + + return LocatedField(getter, axises, setter=setter, array=a.__array__) + + return _maker + + +def index_field(axis): + return LocatedField(lambda index: index[0], (axis,)) + + +@iterator.builtins.shift.register(EMBEDDED) +def shift(*offsets): + def impl(iter): + return iter.shift(*reversed(offsets)) + + return impl + + +@dataclass +class Column: + axis: CartesianAxis + range: range + + +class ScanArgIterator: + def __init__(self, wrapped_iter, k_pos, *, offsets=[]) -> None: + self.wrapped_iter = wrapped_iter + self.offsets = offsets + self.k_pos = k_pos + + def deref(self): + return self.wrapped_iter.deref()[self.k_pos] + + def shift(self, *offsets): + return ScanArgIterator(self.wrapped_iter, offsets=[*offsets, *self.offsets]) + + +def shifted_scan_arg(k_pos): + def impl(iter): + return ScanArgIterator(iter, k_pos=k_pos) + + return impl + + +def fendef_embedded(fun, *args, **kwargs): + assert "offset_provider" in kwargs + + @iterator.runtime.closure.register(EMBEDDED) + def closure(domain, sten, outs, ins): # domain is Dict[axis, range] + + 
column = None + if "column_axis" in kwargs: + _column_axis = kwargs["column_axis"] + column = Column(_column_axis, domain[_column_axis]) + del domain[_column_axis] + + @iterator.builtins.scan.register( + EMBEDDED + ) # TODO this is a bit ugly, alternative: pass scan range via iterator + def scan(scan_pass, is_forward, init): + def impl(*iters): + if column is None: + raise RuntimeError("Column axis is not defined, cannot scan.") + + _range = column.range + if not is_forward: + _range = reversed(_range) + + state = init + if state is None: + state = _None() + cols = [] + for i in _range: + state = scan_pass( + state, *map(shifted_scan_arg(i), iters) + ) # more generic scan returns state and result as 2 different things + cols.append([*tupelize(state)]) + + cols = tuple(map(np.asarray, (map(list, zip(*cols))))) + # transpose to get tuple of columns as np array + + if not is_forward: + cols = tuple(map(np.flip, cols)) + return cols + + return impl + + for pos in domain_iterator(domain): + ins_iters = list( + make_in_iterator( + inp, + pos, + kwargs["offset_provider"], + column_axis=column.axis if column is not None else None, + ) + for inp in ins + ) + res = sten(*ins_iters) + if not isinstance(res, tuple): + res = (res,) + if not len(res) == len(outs): + IndexError("Number of return values doesn't match number of output fields.") + + for r, out in zip(res, outs): + if column is None: + ordered_indices = get_ordered_indices(out.axises, pos) + out[ordered_indices] = r + else: + colpos = pos.copy() + for k in column.range: + colpos[column.axis] = k + ordered_indices = get_ordered_indices(out.axises, colpos) + out[ordered_indices] = r[k] + + fun(*args) + + +iterator.runtime.fendef_registry[None] = fendef_embedded diff --git a/src/iterator/ir.py b/src/iterator/ir.py new file mode 100644 index 0000000000..856becdea7 --- /dev/null +++ b/src/iterator/ir.py @@ -0,0 +1,116 @@ +from typing import List, Union +from eve import Node +from eve.traits import SymbolName, SymbolTableTrait +from eve.type_definitions import SymbolRef +from iterator.util.sym_validation import validate_symbol_refs + + +class Sym(Node): # helper + id: SymbolName + + +class Expr(Node): + ... 
+ + +class BoolLiteral(Expr): + value: bool + + +class IntLiteral(Expr): + value: int + + +class FloatLiteral(Expr): + value: float # TODO other float types + + +class StringLiteral(Expr): + value: str + + +class NoneLiteral(Expr): + _none_literal: int = 0 + + +class OffsetLiteral(Expr): + value: Union[int, str] + + +class AxisLiteral(Expr): + value: str + + +class SymRef(Expr): + id: SymbolRef + + +class Lambda(Expr, SymbolTableTrait): + params: List[Sym] + expr: Expr + + +class FunCall(Expr): + fun: Expr # VType[Callable] + args: List[Expr] + + +class FunctionDefinition(Node, SymbolTableTrait): + id: SymbolName + params: List[Sym] + expr: Expr + + def __eq__(self, other): + return isinstance(other, FunctionDefinition) and self.id == other.id + + def __hash__(self): + return hash(self.id) + + +class Setq(Node): + id: SymbolName + expr: Expr + + +class StencilClosure(Node): + domain: Expr + stencil: Expr + outputs: List[SymRef] + inputs: List[SymRef] + + +class FencilDefinition(Node, SymbolTableTrait): + id: SymbolName + params: List[Sym] + closures: List[StencilClosure] + + +class Program(Node, SymbolTableTrait): + function_definitions: List[FunctionDefinition] + fencil_definitions: List[FencilDefinition] + setqs: List[Setq] + + builtin_functions = list( + Sym(id=name) + for name in [ + "domain", + "named_range", + "compose", + "lift", + "is_none", + "make_tuple", + "nth", + "reduce", + "deref", + "shift", + "scan", + "plus", + "minus", + "mul", + "div", + "greater", + "less", + "if_", + ] + ) + _validate_symbol_refs = validate_symbol_refs() diff --git a/src/iterator/library.py b/src/iterator/library.py new file mode 100644 index 0000000000..b0e1dbfb09 --- /dev/null +++ b/src/iterator/library.py @@ -0,0 +1,15 @@ +from iterator.builtins import reduce + + +def sum(fun=None): + if fun is None: + return reduce(lambda a, b: a + b, 0) + else: + return reduce(lambda first, a, b: first + fun(a, b), 0) # TODO tracing for *args + + +def dot(a, b): + return reduce(lambda acc, a_n, c_n: acc + a_n * c_n, 0)(a, b) + + +# dot = reduce(lambda a, b, c: a + b * c, 0) diff --git a/src/iterator/runtime.py b/src/iterator/runtime.py new file mode 100644 index 0000000000..5b74a456c4 --- /dev/null +++ b/src/iterator/runtime.py @@ -0,0 +1,84 @@ +from typing import Union +from dataclasses import dataclass + +from iterator.builtins import BackendNotSelectedError, builtin_dispatch + +__all__ = ["offset", "fundef", "fendef", "closure", "CartesianAxis"] + + +@dataclass +class Offset: + value: Union[int, str] = None + + def __hash__(self) -> int: + return hash(self.value) + + +def offset(value): + return Offset(value) + + +@dataclass +class CartesianAxis: + value: str + + def __hash__(self) -> int: + return hash(self.value) + + +fendef_registry = {} + + +# TODO the dispatching is linear, not sure if there is an easy way to make it constant +def fendef(*dec_args, **dec_kwargs): + def wrapper(fun): + def impl(*args, **kwargs): + kwargs = {**kwargs, **dec_kwargs} + + for key, val in fendef_registry.items(): + if key is not None and key(kwargs): + val(fun, *args, **kwargs) + return + if None in fendef_registry: + fendef_registry[None](fun, *args, **kwargs) + return + raise RuntimeError("Unreachable") + + return impl + + if len(dec_args) == 1 and len(dec_kwargs) == 0 and callable(dec_args[0]): + return wrapper(dec_args[0]) + else: + assert len(dec_args) == 0 + return wrapper + + +class FundefDispatcher: + _hook = None + # hook is an object that + # - evaluates to true if it should be used, + # - is callable with an instance 
of FundefDispatcher + # - returns callable that takes the function arguments + + def __init__(self, fun) -> None: + self.fun = fun + self.__name__ = fun.__name__ + + def __call__(self, *args): + if type(self)._hook: + return type(self)._hook(self)(*args) + else: + return self.fun(*args) + + @classmethod + def register_hook(cls, hook): + cls._hook = hook + + +def fundef(fun): + return FundefDispatcher(fun) + + +@builtin_dispatch +def closure(*args): + return BackendNotSelectedError() diff --git a/src/iterator/tracing.py b/src/iterator/tracing.py new file mode 100644 index 0000000000..74e4e1d4c6 --- /dev/null +++ b/src/iterator/tracing.py @@ -0,0 +1,303 @@ +from iterator.runtime import CartesianAxis +from eve import Node +import inspect +from iterator.backend_executor import execute_program +from iterator.ir import ( + AxisLiteral, + Expr, + FencilDefinition, + FloatLiteral, + FunCall, + FunctionDefinition, + IntLiteral, + Lambda, + NoneLiteral, + BoolLiteral, + OffsetLiteral, + Program, + StencilClosure, + Sym, + SymRef, +) +from iterator.backends import backend +import iterator + +TRACING = "tracing" + + +def monkeypatch_method(cls): + def decorator(func): + setattr(cls, func.__name__, func) + return func + + return decorator + + +def _patch_Expr(): + @monkeypatch_method(Expr) + def __add__(self, other): + return FunCall(fun=SymRef(id="plus"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __radd__(self, other): + return make_node(other) + self + + @monkeypatch_method(Expr) + def __mul__(self, other): + return FunCall(fun=SymRef(id="mul"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __rmul__(self, other): + return make_node(other) * self + + @monkeypatch_method(Expr) + def __truediv__(self, other): + return FunCall(fun=SymRef(id="div"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __sub__(self, other): + return FunCall(fun=SymRef(id="minus"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __gt__(self, other): + return FunCall(fun=SymRef(id="greater"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __lt__(self, other): + return FunCall(fun=SymRef(id="less"), args=[self, make_node(other)]) + + @monkeypatch_method(Expr) + def __call__(self, *args): + return FunCall(fun=self, args=[*make_node(args)]) + + +_patch_Expr() + + +class PatchedFunctionDefinition(FunctionDefinition): + def __call__(self, *args): + return FunCall(fun=SymRef(id=str(self.id)), args=[*make_node(args)]) + + +def _s(id): + return SymRef(id=id) + + +def trace_function_argument(arg): + if isinstance(arg, iterator.runtime.FundefDispatcher): + make_function_definition(arg.fun) + return _s(arg.fun.__name__) + return arg + + +def _f(fun, *args): + if isinstance(fun, str): + fun = _s(fun) + + args = [trace_function_argument(arg) for arg in args] + return FunCall(fun=fun, args=[*make_node(args)]) + + +# builtins +@iterator.builtins.deref.register(TRACING) +def deref(arg): + return _f("deref", arg) + + +@iterator.builtins.lift.register(TRACING) +def lift(sten): + return _f("lift", sten) + + +@iterator.builtins.reduce.register(TRACING) +def reduce(*args): + return _f("reduce", *args) + + +@iterator.builtins.scan.register(TRACING) +def scan(*args): + return _f("scan", *args) + + +@iterator.builtins.is_none.register(TRACING) +def is_none(*args): + return _f("is_none", *args) + + +@iterator.builtins.make_tuple.register(TRACING) +def make_tuple(*args): + return _f("make_tuple", *args) + + +@iterator.builtins.nth.register(TRACING) +def 
nth(*args): + return _f("nth", *args) + + +@iterator.builtins.compose.register(TRACING) +def compose(*args): + return _f("compose", *args) + + +@iterator.builtins.domain.register(TRACING) +def domain(*args): + return _f("domain", *args) + + +@iterator.builtins.named_range.register(TRACING) +def named_range(*args): + return _f("named_range", *args) + + +@iterator.builtins.if_.register(TRACING) +def if_(*args): + return _f("if_", *args) + + +# shift promotes its arguments to literals, therefore special +@iterator.builtins.shift.register(TRACING) +def shift(*offsets): + offsets = tuple(OffsetLiteral(value=o) if isinstance(o, (str, int)) else o for o in offsets) + return _f("shift", *offsets) + + +@iterator.builtins.plus.register(TRACING) +def plus(*args): + return _f("plus", *args) + + +@iterator.builtins.minus.register(TRACING) +def minus(*args): + return _f("minus", *args) + + +@iterator.builtins.mul.register(TRACING) +def mul(*args): + return _f("mul", *args) + + +@iterator.builtins.div.register(TRACING) +def div(*args): + return _f("div", *args) + + +@iterator.builtins.greater.register(TRACING) +def greater(*args): + return _f("greater", *args) + + +# helpers +def make_node(o): + if isinstance(o, Node): + return o + if callable(o): + if o.__name__ == "": + return lambdadef(o) + if hasattr(o, "__code__") and o.__code__.co_flags & inspect.CO_NESTED: + return lambdadef(o) + if isinstance(o, iterator.runtime.Offset): + return OffsetLiteral(value=o.value) + if isinstance(o, bool): + return BoolLiteral(value=o) + if isinstance(o, int): + return IntLiteral(value=o) + if isinstance(o, float): + return FloatLiteral(value=o) + if isinstance(o, CartesianAxis): + return AxisLiteral(value=o.value) + if isinstance(o, tuple): + return tuple(make_node(arg) for arg in o) + if isinstance(o, list): + return list(make_node(arg) for arg in o) + if o is None: + return NoneLiteral() + raise NotImplementedError(f"Cannot handle {o}") + + +def trace_function_call(fun): + body = fun(*list(_s(param) for param in inspect.signature(fun).parameters.keys())) + return make_node(body) if body is not None else None + + +def lambdadef(fun): + return Lambda( + params=list(Sym(id=param) for param in inspect.signature(fun).parameters.keys()), + expr=trace_function_call(fun), + ) + + +def make_function_definition(fun): + res = PatchedFunctionDefinition( + id=fun.__name__, + params=list(Sym(id=param) for param in inspect.signature(fun).parameters.keys()), + expr=trace_function_call(fun), + ) + Tracer.add_fundef(res) + return res + + +class FundefTracer: + def __call__(self, fundef_dispatcher: iterator.runtime.FundefDispatcher): + def fun(*args): + res = make_function_definition(fundef_dispatcher.fun) + return res(*args) + + return fun + + def __bool__(self): + return iterator.builtins.builtin_dispatch.key == TRACING + + +iterator.runtime.FundefDispatcher.register_hook(FundefTracer()) + + +class Tracer: + fundefs = [] + closures = [] + + @classmethod + def add_fundef(cls, fun): + if not fun in cls.fundefs: + cls.fundefs.append(fun) + + @classmethod + def add_closure(cls, closure): + cls.closures.append(closure) + + def __enter__(self): + iterator.builtins.builtin_dispatch.push_key(TRACING) + + def __exit__(self, exc_type, exc_value, exc_traceback): + type(self).fundefs = [] + type(self).closures = [] + iterator.builtins.builtin_dispatch.pop_key() + + +@iterator.runtime.closure.register(TRACING) +def closure(domain, stencil, outputs, inputs): + stencil(*list(_s(param) for param in inspect.signature(stencil).parameters.keys())) + 
Tracer.add_closure( + StencilClosure( + domain=domain, + stencil=SymRef(id=str(stencil.__name__)), + outputs=outputs, + inputs=inputs, + ) + ) + + +def fendef_tracing(fun, *args, **kwargs): + with Tracer() as _: + trace_function_call(fun) + + fencil = FencilDefinition( + id=fun.__name__, + params=list(Sym(id=param) for param in inspect.signature(fun).parameters.keys()), + closures=Tracer.closures, + ) + prog = Program(function_definitions=Tracer.fundefs, fencil_definitions=[fencil], setqs=[]) + # after tracing is done + execute_program(prog, *args, **kwargs) + + +iterator.runtime.fendef_registry[lambda kwargs: "backend" in kwargs] = fendef_tracing diff --git a/src/iterator/util/__init__.py b/src/iterator/util/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/iterator/util/sym_validation.py b/src/iterator/util/sym_validation.py new file mode 100644 index 0000000000..a8432137aa --- /dev/null +++ b/src/iterator/util/sym_validation.py @@ -0,0 +1,51 @@ +from typing import Any, Dict, List, Type +from eve.traits import SymbolTableTrait +from eve.type_definitions import SymbolRef +from eve.visitors import NodeVisitor +import pydantic +from eve.typingx import RootValidatorType, RootValidatorValuesType +from eve import Node + + +def validate_symbol_refs() -> RootValidatorType: + """Validate that symbol refs are found in a symbol table valid at the current scope.""" + + def _impl( + cls: Type[pydantic.BaseModel], values: RootValidatorValuesType + ) -> RootValidatorValuesType: + class SymtableValidator(NodeVisitor): + def __init__(self) -> None: + self.missing_symbols: List[str] = [] + + def visit_Node( + self, node: Node, *, symtable: Dict[str, Any], **kwargs: Any + ) -> None: + for name, metadata in node.__node_children__.items(): + if isinstance(metadata["definition"].type_, type) and issubclass( + metadata["definition"].type_, SymbolRef + ): + if getattr(node, name) and getattr(node, name) not in symtable: + self.missing_symbols.append(getattr(node, name)) + + if isinstance(node, SymbolTableTrait): + symtable = {**symtable, **node.symtable_} + self.generic_visit(node, symtable=symtable, **kwargs) + + @classmethod + def apply(cls, node: Node, *, symtable: Dict[str, Any]) -> List[str]: + instance = cls() + instance.visit(node, symtable=symtable) + return instance.missing_symbols + + missing_symbols = [] + for v in values.values(): + missing_symbols.extend( + SymtableValidator.apply(v, symtable=values["symtable_"]) + ) + + if len(missing_symbols) > 0: + raise ValueError("Symbols {} not found.".format(missing_symbols)) + + return values + + return pydantic.root_validator(allow_reuse=True, skip_on_failure=True)(_impl) diff --git a/src/iterator/utils.py b/src/iterator/utils.py new file mode 100644 index 0000000000..dd7fa908a9 --- /dev/null +++ b/src/iterator/utils.py @@ -0,0 +1,19 @@ +# GT4Py New Semantic Model - GridTools Framework +# +# Copyright (c) 2014-2021, ETH Zurich All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. GT4Py +# New Semantic Model is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or any later version. +# See the LICENSE.txt file at the top-level directory of this distribution for +# a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + + +def tupelize(tup): + if isinstance(tup, tuple): + return tup + else: + return (tup,) diff --git a/tests/iterator_tests/__init__.py b/tests/iterator_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/iterator_tests/fvm_nabla_setup.py b/tests/iterator_tests/fvm_nabla_setup.py new file mode 100644 index 0000000000..c656393513 --- /dev/null +++ b/tests/iterator_tests/fvm_nabla_setup.py @@ -0,0 +1,204 @@ +# GT4Py New Semantic Model - GridTools Framework +# +# Copyright (c) 2014-2021, ETH Zurich All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. GT4Py +# New Semantic Model is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or any later version. +# See the LICENSE.txt file at the top-level directory of this distribution for +# a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from atlas4py import ( + StructuredGrid, + Topology, + Config, + StructuredMeshGenerator, + functionspace, + build_edges, + build_node_to_edge_connectivity, + build_median_dual_mesh, +) +import numpy as np +import math + + +def assert_close(expected, actual): + assert math.isclose(expected, actual), "expected={}, actual={}".format( + expected, actual + ) + + +class nabla_setup: + def _default_config(): + config = Config() + config["triangulate"] = True + config["angle"] = 20.0 + return config + + def __init__(self, *, grid=StructuredGrid("O32"), config=_default_config()): + mesh = StructuredMeshGenerator(config).generate(grid) + + fs_edges = functionspace.EdgeColumns(mesh, halo=1) + fs_nodes = functionspace.NodeColumns(mesh, halo=1) + + build_edges(mesh) + build_node_to_edge_connectivity(mesh) + build_median_dual_mesh(mesh) + + edges_per_node = max( + [ + mesh.nodes.edge_connectivity.cols(node) + for node in range(0, fs_nodes.size) + ] + ) + + self.mesh = mesh + self.fs_edges = fs_edges + self.fs_nodes = fs_nodes + self.edges_per_node = edges_per_node + + @property + def edges2node_connectivity(self): + return self.mesh.edges.node_connectivity + + @property + def nodes2edge_connectivity(self): + return self.mesh.nodes.edge_connectivity + + @property + def nodes_size(self): + return self.fs_nodes.size + + @property + def edges_size(self): + return self.fs_edges.size + + @property + def sign_field(self): + node2edge_sign = np.zeros((self.nodes_size, self.edges_per_node)) + edge_flags = np.array(self.mesh.edges.flags()) + + def is_pole_edge(e): + return Topology.check(edge_flags[e], Topology.POLE) + + for jnode in range(0, self.nodes_size): + node_edge_con = self.mesh.nodes.edge_connectivity + edge_node_con = self.mesh.edges.node_connectivity + for jedge in range(0, node_edge_con.cols(jnode)): + iedge = node_edge_con[jnode, jedge] + ip1 = edge_node_con[iedge, 0] + if jnode == ip1: + node2edge_sign[jnode, jedge] = 1.0 + else: + node2edge_sign[jnode, jedge] = -1.0 + if is_pole_edge(iedge): + node2edge_sign[jnode, jedge] = 1.0 + return node2edge_sign + + @property + def S_fields(self): + S = np.array(self.mesh.edges.field("dual_normals"), copy=False) + S_MXX = np.zeros((self.edges_size)) + S_MYY = np.zeros((self.edges_size)) + + MXX = 0 + MYY = 1 + + rpi = 2.0 * math.asin(1.0) + radius = 6371.22e03 + deg2rad = 2.0 * rpi / 360.0 + + for i in range(0, self.edges_size): + S_MXX[i] = S[i, MXX] * radius * deg2rad + S_MYY[i] = S[i, 
MYY] * radius * deg2rad + + assert math.isclose(min(S_MXX), -103437.60479272791) + assert math.isclose(max(S_MXX), 340115.33913622628) + assert math.isclose(min(S_MYY), -2001577.7946404363) + assert math.isclose(max(S_MYY), 2001577.7946404363) + + return S_MXX, S_MYY + + @property + def vol_field(self): + rpi = 2.0 * math.asin(1.0) + radius = 6371.22e03 + deg2rad = 2.0 * rpi / 360.0 + vol_atlas = np.array(self.mesh.nodes.field("dual_volumes"), copy=False) + # dual_volumes 4.6510228700066421 68.891611253882218 12.347560975609632 + assert_close(4.6510228700066421, min(vol_atlas)) + assert_close(68.891611253882218, max(vol_atlas)) + + vol = np.zeros((vol_atlas.size)) + for i in range(0, vol_atlas.size): + vol[i] = vol_atlas[i] * pow(deg2rad, 2) * pow(radius, 2) + # VOL(min/max): 57510668192.214096 851856184496.32886 + assert_close(57510668192.214096, min(vol)) + assert_close(851856184496.32886, max(vol)) + return vol + + @property + def input_field(self): + klevel = 0 + MXX = 0 + MYY = 1 + rpi = 2.0 * math.asin(1.0) + radius = 6371.22e03 + deg2rad = 2.0 * rpi / 360.0 + + zh0 = 2000.0 + zrad = 3.0 * rpi / 4.0 * radius + zeta = rpi / 16.0 * radius + zlatc = 0.0 + zlonc = 3.0 * rpi / 2.0 + + m_rlonlatcr = self.fs_nodes.create_field( + name="m_rlonlatcr", + levels=1, + dtype=np.float64, + variables=self.edges_per_node, + ) + rlonlatcr = np.array(m_rlonlatcr, copy=False) + + m_rcoords = self.fs_nodes.create_field( + name="m_rcoords", levels=1, dtype=np.float64, variables=self.edges_per_node + ) + rcoords = np.array(m_rcoords, copy=False) + + m_rcosa = self.fs_nodes.create_field(name="m_rcosa", levels=1, dtype=np.float64) + rcosa = np.array(m_rcosa, copy=False) + + m_rsina = self.fs_nodes.create_field(name="m_rsina", levels=1, dtype=np.float64) + rsina = np.array(m_rsina, copy=False) + + m_pp = self.fs_nodes.create_field(name="m_pp", levels=1, dtype=np.float64) + rzs = np.array(m_pp, copy=False) + + rcoords_deg = np.array(self.mesh.nodes.field("lonlat")) + + for jnode in range(0, self.nodes_size): + for i in range(0, 2): + rcoords[jnode, klevel, i] = rcoords_deg[jnode, i] * deg2rad + rlonlatcr[jnode, klevel, i] = rcoords[ + jnode, klevel, i + ] # This is not my pattern! + rcosa[jnode, klevel] = math.cos(rlonlatcr[jnode, klevel, MYY]) + rsina[jnode, klevel] = math.sin(rlonlatcr[jnode, klevel, MYY]) + for jnode in range(0, self.nodes_size): + zlon = rlonlatcr[jnode, klevel, MXX] + zdist = math.sin(zlatc) * rsina[jnode, klevel] + math.cos(zlatc) * rcosa[ + jnode, klevel + ] * math.cos(zlon - zlonc) + zdist = radius * math.acos(zdist) + rzs[jnode, klevel] = 0.0 + if zdist < zrad: + rzs[jnode, klevel] = rzs[jnode, klevel] + 0.5 * zh0 * ( + 1.0 + math.cos(rpi * zdist / zrad) + ) * math.pow(math.cos(rpi * zdist / zeta), 2) + + assert_close(0.0000000000000000, min(rzs)) + assert_close(1965.4980340735883, max(rzs)) + return rzs[:, klevel] diff --git a/tests/iterator_tests/hdiff_reference.py b/tests/iterator_tests/hdiff_reference.py new file mode 100644 index 0000000000..e0f3bb31df --- /dev/null +++ b/tests/iterator_tests/hdiff_reference.py @@ -0,0 +1,41 @@ +# GT4Py New Semantic Model - GridTools Framework +# +# Copyright (c) 2014-2021, ETH Zurich All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. GT4Py +# New Semantic Model is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or any later version. 
+# See the LICENSE.txt file at the top-level directory of this distribution for +# a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import pytest +import numpy as np + + +def hdiff_reference_impl(): + shape = (5, 7, 5) + rng = np.random.default_rng() + inp = rng.normal(size=(shape[0] + 4, shape[1] + 4, shape[2])) + coeff = rng.normal(size=shape) + + lap = 4 * inp[1:-1, 1:-1, :] - ( + inp[2:, 1:-1, :] + inp[:-2, 1:-1, :] + inp[1:-1, 2:, :] + inp[1:-1, :-2, :] + ) + uflx = lap[1:, 1:-1, :] - lap[:-1, 1:-1, :] + flx = np.where(uflx * (inp[2:-1, 2:-2, :] - inp[1:-2, 2:-2, :]) > 0, 0, uflx) + ufly = lap[1:-1, 1:, :] - lap[1:-1, :-1, :] + fly = np.where(ufly * (inp[2:-2, 2:-1, :] - inp[2:-2, 1:-2, :]) > 0, 0, ufly) + out = inp[2:-2, 2:-2, :] - coeff * ( + flx[1:, :, :] - flx[:-1, :, :] + fly[:, 1:, :] - fly[:, :-1, :] + ) + + return inp, coeff, out + + +@pytest.fixture +def hdiff_reference(): + return hdiff_reference_impl() diff --git a/tests/iterator_tests/test_anton_toy.py b/tests/iterator_tests/test_anton_toy.py new file mode 100644 index 0000000000..5b9da41c4d --- /dev/null +++ b/tests/iterator_tests/test_anton_toy.py @@ -0,0 +1,83 @@ +from iterator.builtins import * +from iterator.embedded import np_as_located_field +from iterator.runtime import * +import numpy as np + + +@fundef +def ldif(d): + return lambda inp: deref(shift(d, -1)(inp)) - deref(inp) + + +@fundef +def rdif(d): + # return compose(ldif(d), shift(d, 1)) + return lambda inp: ldif(d)(shift(d, 1)(inp)) + + +@fundef +def dif2(d): + # return compose(ldif(d), lift(rdif(d))) + return lambda inp: ldif(d)(lift(rdif(d))(inp)) + + +i = offset("i") +j = offset("j") + + +@fundef +def lap(inp): + return dif2(i)(inp) + dif2(j)(inp) + + +IDim = CartesianAxis("IDim") +JDim = CartesianAxis("JDim") +KDim = CartesianAxis("KDim") + + +@fendef +def fencil(x, y, z, output, input): + closure( + domain(named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z)), + lap, + [output], + [input], + ) + + +fencil(*([None] * 5), backend="lisp") +fencil(*([None] * 5), backend="cpptoy") + + +def naive_lap(inp): + shape = [inp.shape[0] - 2, inp.shape[1] - 2, inp.shape[2]] + out = np.zeros(shape) + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(0, shape[2]): + out[i, j, k] = -4 * inp[i, j, k] + ( + inp[i + 1, j, k] + inp[i - 1, j, k] + inp[i, j + 1, k] + inp[i, j - 1, k] + ) + return out + + +def test_anton_toy(): + shape = [5, 7, 9] + rng = np.random.default_rng() + inp = np_as_located_field(IDim, JDim, KDim, origin={IDim: 1, JDim: 1, KDim: 0})( + rng.normal(size=(shape[0] + 2, shape[1] + 2, shape[2])), + ) + out = np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) + ref = naive_lap(inp) + + fencil( + shape[0], + shape[1], + shape[2], + out, + inp, + backend="double_roundtrip", + offset_provider={"i": IDim, "j": JDim}, + ) + + assert np.allclose(out, ref) diff --git a/tests/iterator_tests/test_cartesian_offset_provider.py b/tests/iterator_tests/test_cartesian_offset_provider.py new file mode 100644 index 0000000000..de198b23bd --- /dev/null +++ b/tests/iterator_tests/test_cartesian_offset_provider.py @@ -0,0 +1,79 @@ +from iterator.builtins import * +from iterator.runtime import * + +from iterator.embedded import np_as_located_field +import numpy as np + +I = offset("I") +J = offset("J") +I_loc = CartesianAxis("I_loc") +J_loc = CartesianAxis("J_loc") + + +@fundef +def foo(inp): + return deref(shift(J, 1)(inp)) + + +@fendef(offset_provider={"I": I_loc, "J": J_loc}) +def 
fencil(output, input): + closure( + domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), + foo, + [output], + [input], + ) + + +@fendef(offset_provider={"I": J_loc, "J": I_loc}) +def fencil_swapped(output, input): + closure( + domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), + foo, + [output], + [input], + ) + + +def test_cartesian_offset_provider(): + inp = np_as_located_field(I_loc, J_loc)(np.asarray([[0, 42], [1, 43]])) + out = np_as_located_field(I_loc, J_loc)(np.asarray([[-1]])) + + fencil(out, inp) + assert out[0][0] == 42 + + fencil_swapped(out, inp) + assert out[0][0] == 1 + + fencil(out, inp, backend="embedded") + assert out[0][0] == 42 + + fencil(out, inp, backend="double_roundtrip") + assert out[0][0] == 42 + + +@fundef +def delay_complete_shift(inp): + return deref(shift(I, J, 1, 1)(inp)) + + +@fendef(offset_provider={"I": J_loc, "J": I_loc}) +def delay_complete_shift_fencil(output, input): + closure( + domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), + delay_complete_shift, + [output], + [input], + ) + + +def test_delay_complete_shift(): + inp = np_as_located_field(I_loc, J_loc)(np.asarray([[0, 42], [1, 43]])) + + out = np_as_located_field(I_loc, J_loc)(np.asarray([[-1]])) + delay_complete_shift_fencil(out, inp) + assert out[0, 0] == 43 + + out = np_as_located_field(I_loc, J_loc)(np.asarray([[-1]])) + delay_complete_shift_fencil(out, inp, backend="embedded", debug=True) + assert out[0, 0] == 43 diff --git a/tests/iterator_tests/test_column_stencil.py b/tests/iterator_tests/test_column_stencil.py new file mode 100644 index 0000000000..b8458ac3cb --- /dev/null +++ b/tests/iterator_tests/test_column_stencil.py @@ -0,0 +1,126 @@ +from iterator.embedded import np_as_located_field +from iterator.runtime import * +from iterator.builtins import * +import numpy as np + +I = offset("I") +K = offset("K") + + +@fundef +def multiply_stencil(inp): + return deref(shift(K, 1, I, 1)(inp)) + + +KDim = CartesianAxis("KDim") +IDim = CartesianAxis("IDim") + + +@fendef(column_axis=KDim) +def fencil(i_size, k_size, inp, out): + closure( + domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)), + multiply_stencil, + [out], + [inp], + ) + + +def test_column_stencil(): + shape = [5, 7] + inp = np_as_located_field(IDim, KDim)( + np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 1]) + ) + out = np_as_located_field(IDim, KDim)(np.zeros(shape)) + + ref = np.asarray(inp)[1:, 1:] + + fencil(shape[0], shape[1], inp, out, offset_provider={"I": IDim, "K": KDim}) + + assert np.allclose(ref, out) + + +def test_column_stencil_with_k_origin(): + shape = [5, 7] + raw_inp = np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 2]) + inp = np_as_located_field(IDim, KDim, origin={IDim: 0, KDim: 1})(raw_inp) + out = np_as_located_field(IDim, KDim)(np.zeros(shape)) + + ref = np.asarray(inp)[1:, 2:] + + fencil(shape[0], shape[1], inp, out, offset_provider={"I": IDim, "K": KDim}) + + assert np.allclose(ref, out) + + +@fundef +def sum_scanpass(state, inp): + return if_(is_none(state), deref(inp), state + deref(inp)) + + +@fundef +def ksum(inp): + return scan(sum_scanpass, True, None)(inp) + + +@fendef(column_axis=KDim) +def ksum_fencil(i_size, k_size, inp, out): + closure( + domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)), + ksum, + [out], + [inp], + ) + + +def test_ksum_scan(): + shape = [1, 7] + inp = np_as_located_field(IDim, KDim)(np.asarray([list(range(7))])) + out = np_as_located_field(IDim, KDim)(np.zeros(shape)) + + ref = 
np.asarray([[0, 1, 3, 6, 10, 15, 21]]) + + ksum_fencil( + shape[0], + shape[1], + inp, + out, + offset_provider={"I": IDim, "K": KDim}, + backend="double_roundtrip", + ) + + assert np.allclose(ref, np.asarray(out)) + + +@fundef +def ksum_back(inp): + return scan(sum_scanpass, False, None)(inp) + + +@fendef(column_axis=KDim) +def ksum_back_fencil(i_size, k_size, inp, out): + closure( + domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)), + ksum_back, + [out], + [inp], + ) + + +def test_ksum_back_scan(): + shape = [1, 7] + inp = np_as_located_field(IDim, KDim)(np.asarray([list(range(7))])) + out = np_as_located_field(IDim, KDim)(np.zeros(shape)) + + ref = np.asarray([[21, 21, 20, 18, 15, 11, 6]]) + + ksum_back_fencil( + shape[0], + shape[1], + inp, + out, + offset_provider={"I": IDim, "K": KDim}, + backend="double_roundtrip", + ) + + assert np.allclose(ref, np.asarray(out)) diff --git a/tests/iterator_tests/test_fvm_nabla.py b/tests/iterator_tests/test_fvm_nabla.py new file mode 100644 index 0000000000..0e35790c59 --- /dev/null +++ b/tests/iterator_tests/test_fvm_nabla.py @@ -0,0 +1,285 @@ +# GT4Py New Semantic Model - GridTools Framework +# +# Copyright (c) 2014-2021, ETH Zurich All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. GT4Py +# New Semantic Model is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or any later version. +# See the LICENSE.txt file at the top-level directory of this distribution for +# a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from iterator.atlas_utils import AtlasTable +from iterator.embedded import NeighborTableOffsetProvider, np_as_located_field +from iterator.runtime import * +from iterator.builtins import * +from iterator import library +from .fvm_nabla_setup import ( + assert_close, + nabla_setup, +) +import numpy as np + + +Vertex = CartesianAxis("Vertex") +Edge = CartesianAxis("Edge") + +V2E = offset("V2E") +E2V = offset("E2V") + + +@fundef +def compute_zavgS(pp, S_M): + zavg = 0.5 * (deref(shift(E2V, 0)(pp)) + deref(shift(E2V, 1)(pp))) + # zavg = 0.5 * reduce(lambda a, b: a + b, 0)(shift(E2V)(pp)) + # zavg = 0.5 * library.sum()(shift(E2V)(pp)) + return deref(S_M) * zavg + + +@fendef +def compute_zavgS_fencil( + n_edges, + out, + pp, + S_M, +): + closure( + domain(named_range(Edge, 0, n_edges)), + compute_zavgS, + [out], + [pp, S_M], + ) + + +@fundef +def compute_pnabla(pp, S_M, sign, vol): + zavgS = lift(compute_zavgS)(pp, S_M) + # pnabla_M = reduce(lambda a, b, c: a + b * c, 0)(shift(V2E)(zavgS), sign) + # pnabla_M = library.sum(lambda a, b: a * b)(shift(V2E)(zavgS), sign) + pnabla_M = library.dot(shift(V2E)(zavgS), sign) + return pnabla_M / deref(vol) + + +@fendef +def nabla( + n_nodes, + out_MXX, + out_MYY, + pp, + S_MXX, + S_MYY, + sign, + vol, +): + # TODO replace by single stencil which returns tuple + closure( + domain(named_range(Vertex, 0, n_nodes)), + compute_pnabla, + [out_MXX], + [pp, S_MXX, sign, vol], + ) + closure( + domain(named_range(Vertex, 0, n_nodes)), + compute_pnabla, + [out_MYY], + [pp, S_MYY, sign, vol], + ) + + +def test_compute_zavgS(): + setup = nabla_setup() + + pp = np_as_located_field(Vertex)(setup.input_field) + S_MXX, S_MYY = tuple(map(np_as_located_field(Edge), setup.S_fields)) + + zavgS = np_as_located_field(Edge)(np.zeros((setup.edges_size))) + + e2v = 
NeighborTableOffsetProvider(AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2) + + compute_zavgS_fencil( + setup.edges_size, + zavgS, + pp, + S_MXX, + offset_provider={"E2V": e2v}, + ) + + assert_close(-199755464.25741270, min(zavgS)) + assert_close(388241977.58389181, max(zavgS)) + + compute_zavgS_fencil( + setup.edges_size, + zavgS, + pp, + S_MYY, + offset_provider={"E2V": e2v}, + ) + assert_close(-1000788897.3202186, min(zavgS)) + assert_close(1000788897.3202186, max(zavgS)) + + +def test_nabla(): + setup = nabla_setup() + + sign = np_as_located_field(Vertex, V2E)(setup.sign_field) + pp = np_as_located_field(Vertex)(setup.input_field) + S_MXX, S_MYY = tuple(map(np_as_located_field(Edge), setup.S_fields)) + vol = np_as_located_field(Vertex)(setup.vol_field) + + pnabla_MXX = np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + pnabla_MYY = np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + + e2v = NeighborTableOffsetProvider(AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2) + v2e = NeighborTableOffsetProvider(AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7) + + nabla( + setup.nodes_size, + pnabla_MXX, + pnabla_MYY, + pp, + S_MXX, + S_MYY, + sign, + vol, + backend="double_roundtrip", + offset_provider={"E2V": e2v, "V2E": v2e}, + ) + + assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) + assert_close(3.5455427772565435e-003, max(pnabla_MXX)) + assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) + assert_close(3.3540113705465301e-003, max(pnabla_MYY)) + + +# @stencil +# def sign(e2v, v2e, node_indices, is_pole_edge): +# node_indices_of_neighbor_edge = node_indices[e2v[v2e]] +# pole_flag_of_neighbor_edges = is_pole_edge[v2e] +# sign_field = if_( +# pole_flag_of_neighbor_edges +# | (broadcast(V2E)(node_indices) == node_indices_of_neighbor_edge[E2V(0)]), +# constant_field(Vertex, V2E)(1.0), +# constant_field(Vertex, V2E)(-1.0), +# ) +# return sign_field + + +# @stencil +# def compute_pnabla_sign(e2v, v2e, pp, S_M, node_indices, is_pole_edge, vol): +# zavgS = compute_zavgS(pp[e2v], S_M)[v2e] +# pnabla_M = sum_reduce(V2E)(zavgS * sign(e2v, v2e, node_indices, is_pole_edge)) + +# return pnabla_M / vol + + +# def nabla_sign( +# e2v, +# v2e, +# pp, +# S_MXX, +# S_MYY, +# node_indices, +# is_pole_edge, +# vol, +# ): +# return ( +# compute_pnabla_sign(e2v, v2e, pp, S_MXX, node_indices, is_pole_edge, vol), +# compute_pnabla_sign(e2v, v2e, pp, S_MYY, node_indices, is_pole_edge, vol), +# ) + + +# def test_nabla_from_sign_stencil(): +# setup = nabla_setup() + +# pp = array_as_field(Vertex)(setup.input_field) +# S_MXX, S_MYY = tuple(map(array_as_field(Edge), setup.S_fields)) +# vol = array_as_field(Vertex)(setup.vol_field) + +# edge_flags = np.array(setup.mesh.edges.flags()) +# is_pole_edge = array_as_field(Edge)( +# np.array([Topology.check(flag, Topology.POLE) for flag in edge_flags]) +# ) + +# node_index_field = index_field(Vertex, range(setup.nodes_size)) + +# e2v = make_sparse_index_field_from_atlas_connectivity( +# setup.edges2node_connectivity, Edge, E2V, Vertex +# ) +# v2e = make_sparse_index_field_from_atlas_connectivity( +# setup.nodes2edge_connectivity, Vertex, V2E, Edge +# ) + +# pnabla_MXX = np.zeros((setup.nodes_size)) +# pnabla_MYY = np.zeros((setup.nodes_size)) + +# print(f"nodes: {setup.nodes_size}") +# print(f"edges: {setup.edges_size}") + +# pnabla_MXX[:], pnabla_MYY[:] = nabla_sign( +# e2v, v2e, pp, S_MXX, S_MYY, node_index_field, is_pole_edge, vol +# ) + +# assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) +# 
assert_close(3.5455427772565435e-003, max(pnabla_MXX)) +# assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) +# assert_close(3.3540113705465301e-003, max(pnabla_MYY)) + + +# @stencil +# def compute_pnabla_on_nodes(v2e, v2e_e2v_pp, S_M, sign, vol): +# zavgS = S_M[v2e] * 0.5 * (v2e_e2v_pp[E2V(0)] + v2e_e2v_pp[E2V(1)]) +# pnabla_M = sum_reduce(V2E)(zavgS * sign) + +# return pnabla_M / vol + + +# def nabla_on_nodes( +# v2e2v, +# v2e, +# pp, +# S_MXX, +# S_MYY, +# sign, +# vol, +# ): +# v2e_e2v_pp = materialize(pp[v2e2v]) +# return ( +# compute_pnabla_on_nodes(v2e, v2e_e2v_pp, S_MXX, sign, vol), +# compute_pnabla_on_nodes(v2e, v2e_e2v_pp, S_MYY, sign, vol), +# ) + + +# def test_nabla_on_nodes(): +# setup = nabla_setup() + +# sign_acc = array_as_field(Vertex, V2E)(setup.sign_field) +# pp = array_as_field(Vertex)(setup.input_field) +# S_MXX, S_MYY = tuple(map(array_as_field(Edge), setup.S_fields)) +# vol = array_as_field(Vertex)(setup.vol_field) + +# e2v = make_sparse_index_field_from_atlas_connectivity( +# setup.edges2node_connectivity, Edge, E2V, Vertex +# ) +# v2e = make_sparse_index_field_from_atlas_connectivity( +# setup.nodes2edge_connectivity, Vertex, V2E, Edge +# ) +# v2e2v = e2v[ +# v2e +# ] # TODO materialize is broken because I don't preserver optional if materialized as np array + +# pnabla_MXX = np.zeros((setup.nodes_size)) +# pnabla_MYY = np.zeros((setup.nodes_size)) + +# print(f"nodes: {setup.nodes_size}") +# print(f"edges: {setup.edges_size}") + +# pnabla_MXX[:], pnabla_MYY[:] = nabla_on_nodes( +# v2e2v, v2e, pp, S_MXX, S_MYY, sign_acc, vol +# ) + +# assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) +# assert_close(3.5455427772565435e-003, max(pnabla_MXX)) +# assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) +# assert_close(3.3540113705465301e-003, max(pnabla_MYY)) diff --git a/tests/iterator_tests/test_hdiff.py b/tests/iterator_tests/test_hdiff.py new file mode 100644 index 0000000000..0d086bcf6f --- /dev/null +++ b/tests/iterator_tests/test_hdiff.py @@ -0,0 +1,70 @@ +from iterator.builtins import * +from iterator.runtime import * +from iterator.embedded import np_as_located_field +import numpy as np +from .hdiff_reference import hdiff_reference + +I = offset("I") +J = offset("J") + +IDim = CartesianAxis("IDim") +JDim = CartesianAxis("JDim") + + +@fundef +def laplacian(inp): + return -4.0 * deref(inp) + ( + deref(shift(I, 1)(inp)) + + deref(shift(I, -1)(inp)) + + deref(shift(J, 1)(inp)) + + deref(shift(J, -1)(inp)) + ) + + +@fundef +def flux(d): + def flux_impl(inp): + lap = lift(laplacian)(inp) + flux = deref(lap) - deref(shift(d, 1)(lap)) + return if_(flux * (deref(shift(d, 1)(inp)) - deref(inp)) > 0.0, 0.0, flux) + + return flux_impl + + +@fundef +def hdiff_sten(inp, coeff): + flx = lift(flux(I))(inp) + fly = lift(flux(J))(inp) + return deref(inp) - ( + deref(coeff) + * (deref(flx) - deref(shift(I, -1)(flx)) + deref(fly) - deref(shift(J, -1)(fly))) + ) + + +@fendef(offset_provider={"I": IDim, "J": JDim}) +def hdiff(inp, coeff, out, x, y): + closure( + domain(named_range(IDim, 0, x), named_range(JDim, 0, y)), + hdiff_sten, + [out], + [inp, coeff], + ) + + +hdiff(*([None] * 5), backend="lisp") +hdiff(*([None] * 5), backend="cpptoy") + + +def test_hdiff(hdiff_reference): + inp, coeff, out = hdiff_reference + shape = (out.shape[0], out.shape[1]) + + inp_s = np_as_located_field(IDim, JDim, origin={IDim: 2, JDim: 2})(inp[:, :, 0]) + coeff_s = np_as_located_field(IDim, JDim)(coeff[:, :, 0]) + out_s = np_as_located_field(IDim, JDim)(np.zeros_like(coeff[:, :, 0])) + + 
# hdiff(inp_s, coeff_s, out_s, shape[0], shape[1]) + # hdiff(inp_s, coeff_s, out_s, shape[0], shape[1], backend="embedded") + hdiff(inp_s, coeff_s, out_s, shape[0], shape[1], backend="double_roundtrip") + + assert np.allclose(out[:, :, 0], out_s) diff --git a/tests/iterator_tests/test_horizontal_indirection.py b/tests/iterator_tests/test_horizontal_indirection.py new file mode 100644 index 0000000000..c02710b467 --- /dev/null +++ b/tests/iterator_tests/test_horizontal_indirection.py @@ -0,0 +1,53 @@ +# (defun calc (p_vn input_on_cell) +# (do_some_math +# (deref +# ((if (less (deref p_vn) 0) +# (shift e2c 0) +# (shift e2c 1) +# ) +# input_on_cell +# ) +# ) +# ) +# ) +from numpy.core.numeric import allclose +from iterator.embedded import np_as_located_field +from iterator.runtime import * +from iterator.builtins import * +import numpy as np + +I = offset("I") + + +@fundef +def compute_shift(cond): + return if_(deref(cond) < 0, shift(I, -1), shift(I, 1)) + + +@fundef +def foo(inp, cond): + return deref(compute_shift(cond)(inp)) + + +@fendef +def fencil(size, inp, cond, out): + closure(domain(named_range(IDim, 0, size)), foo, [out], [inp, cond]) + + +IDim = CartesianAxis("IDim") + + +def test_simple_indirection(): + shape = [8] + inp = np_as_located_field(IDim, origin={IDim: 1})(np.asarray(range(shape[0] + 2))) + rng = np.random.default_rng() + cond = np_as_located_field(IDim)(rng.normal(size=shape)) + out = np_as_located_field(IDim)(np.zeros(shape)) + + ref = np.zeros(shape) + for i in range(shape[0]): + ref[i] = inp[i - 1] if cond[i] < 0 else inp[i + 1] + + fencil(shape[0], inp, cond, out, offset_provider={"I": IDim}) + + assert allclose(ref, out) diff --git a/tests/iterator_tests/test_toy_connectivity.py b/tests/iterator_tests/test_toy_connectivity.py new file mode 100644 index 0000000000..fb4e8e3917 --- /dev/null +++ b/tests/iterator_tests/test_toy_connectivity.py @@ -0,0 +1,334 @@ +from dataclasses import field +import numpy as np +from numpy.core.numeric import allclose +from iterator.runtime import * +from iterator.builtins import * +from iterator.embedded import ( + NeighborTableOffsetProvider, + np_as_located_field, + index_field, +) + + +Vertex = CartesianAxis("Vertex") +Edge = CartesianAxis("Edge") + + +# 3x3 periodic edges +# 0 - 1 - 2 - 0 1 2 +# | | | 9 10 11 +# 3 - 4 - 5 - 3 4 5 +# | | | 12 13 14 +# 6 - 7 - 8 - 6 7 8 +# | | | 15 16 17 + + +v2v_arr = np.array( + [ + [1, 3, 2, 6], + [2, 3, 0, 7], + [0, 5, 1, 8], + [4, 6, 5, 0], + [5, 7, 3, 1], + [3, 8, 4, 2], + [7, 0, 8, 3], + [8, 1, 6, 4], + [6, 2, 7, 5], + ] +) + +e2v_arr = np.array( + [ + [0, 1], + [1, 2], + [2, 0], + [3, 4], + [4, 5], + [5, 3], + [6, 7], + [7, 8], + [8, 6], + [0, 3], + [1, 4], + [2, 5], + [3, 6], + [4, 7], + [5, 8], + [6, 0], + [7, 1], + [8, 2], + ] +) + + +# order east, north, west, south (counter-clock wise) +v2e_arr = np.array( + [ + [0, 15, 2, 9], # 0 + [1, 16, 0, 10], + [2, 17, 1, 11], + [3, 9, 5, 12], # 3 + [4, 10, 3, 13], + [5, 11, 4, 14], + [6, 12, 8, 15], # 6 + [7, 13, 6, 16], + [8, 14, 7, 17], + ] +) + +V2E = offset("V2E") +E2V = offset("E2V") + + +@fundef +def sum_edges_to_vertices(in_edges): + return ( + deref(shift(V2E, 0)(in_edges)) + + deref(shift(V2E, 1)(in_edges)) + + deref(shift(V2E, 2)(in_edges)) + + deref(shift(V2E, 3)(in_edges)) + ) + + +@fendef(offset_provider={"V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}) +def e2v_sum_fencil(in_edges, out_vertices): + closure( + domain(named_range(Vertex, 0, 9)), + sum_edges_to_vertices, + [out_vertices], + [in_edges], + ) + + +def 
test_sum_edges_to_vertices(): + inp = index_field(Edge) + out = np_as_located_field(Vertex)(np.zeros([9])) + ref = np.asarray(list(sum(row) for row in v2e_arr)) + + e2v_sum_fencil(inp, out, backend="double_roundtrip") + assert allclose(out, ref) + e2v_sum_fencil(None, None, backend="cpptoy") + + +@fundef +def sum_edges_to_vertices_reduce(in_edges): + return reduce(lambda a, b: a + b, 0)(shift(V2E)(in_edges)) + + +@fendef(offset_provider={"V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}) +def e2v_sum_fencil_reduce(in_edges, out_vertices): + closure( + domain(named_range(Vertex, 0, 9)), + sum_edges_to_vertices_reduce, + [out_vertices], + [in_edges], + ) + + +def test_sum_edges_to_vertices_reduce(): + inp = index_field(Edge) + out = np_as_located_field(Vertex)(np.zeros([9])) + ref = np.asarray(list(sum(row) for row in v2e_arr)) + + e2v_sum_fencil_reduce(None, None, backend="cpptoy") + e2v_sum_fencil_reduce(inp, out, backend="double_roundtrip") + assert allclose(out, ref) + + +@fundef +def sparse_stencil(inp): + return reduce(lambda a, b: a + b, 0)(inp) + + +@fendef +def sparse_fencil(inp, out): + closure( + domain(named_range(Vertex, 0, 9)), + sparse_stencil, + [out], + [inp], + ) + + +def test_sparse_input_field(): + inp = np_as_located_field(Vertex, V2E)(np.asarray([[1, 2, 3, 4]] * 9)) + out = np_as_located_field(Vertex)(np.zeros([9])) + + ref = np.ones([9]) * 10 + + sparse_fencil( + inp, + out, + backend="double_roundtrip", + offset_provider={"V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + ) + + assert allclose(out, ref) + + +V2V = offset("V2V") + + +def test_sparse_input_field_v2v(): + inp = np_as_located_field(Vertex, V2V)(v2v_arr) + out = np_as_located_field(Vertex)(np.zeros([9])) + + ref = np.asarray(list(sum(row) for row in v2v_arr)) + + sparse_fencil( + inp, + out, + backend="double_roundtrip", + offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + ) + + assert allclose(out, ref) + + +@fundef +def deref_stencil(inp): + return deref(shift(V2V, 0)(inp)) + + +@fundef +def lift_stencil(inp): + return deref(shift(V2V, 2)(lift(deref_stencil)(inp))) + + +@fendef +def lift_fencil(inp, out): + closure(domain(named_range(Vertex, 0, 9)), lift_stencil, [out], [inp]) + + +def test_lift(): + inp = index_field(Vertex) + out = np_as_located_field(Vertex)(np.zeros([9])) + ref = np.asarray(np.asarray(range(9))) + + lift_fencil(None, None, backend="cpptoy") + lift_fencil( + inp, + out, + backend="double_roundtrip", + offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + ) + assert allclose(out, ref) + + +@fundef +def sparse_shifted_stencil(inp): + return deref(shift(0, 2)(shift(V2V)(inp))) + + +@fendef +def sparse_shifted_fencil(inp, out): + closure( + domain(named_range(Vertex, 0, 9)), + sparse_shifted_stencil, + [out], + [inp], + ) + + +def test_shift_sparse_input_field(): + inp = np_as_located_field(Vertex, V2V)(v2v_arr) + out = np_as_located_field(Vertex)(np.zeros([9])) + ref = np.asarray(np.asarray(range(9))) + + sparse_shifted_fencil( + inp, + out, + backend="double_roundtrip", + offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + ) + + assert allclose(out, ref) + + +@fundef +def shift_shift_stencil2(inp): + return deref(shift(V2E, 3)(shift(E2V, 1)(inp))) + + +@fundef +def shift_sparse_stencil2(inp): + return deref(shift(1, 3)(shift(V2E)(inp))) + + +@fendef +def sparse_shifted_fencil2(inp_sparse, inp, out1, out2): + closure( + domain(named_range(Vertex, 0, 9)), + 
shift_shift_stencil2, + [out1], + [inp], + ) + closure( + domain(named_range(Vertex, 0, 9)), + shift_sparse_stencil2, + [out2], + [inp_sparse], + ) + + +def test_shift_sparse_input_field2(): + inp = index_field(Vertex) + inp_sparse = np_as_located_field(Edge, E2V)(e2v_arr) + out1 = np_as_located_field(Vertex)(np.zeros([9])) + out2 = np_as_located_field(Vertex)(np.zeros([9])) + + sparse_shifted_fencil2( + inp_sparse, + inp, + out1, + out2, + backend="double_roundtrip", + offset_provider={ + "E2V": NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), + "V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + }, + ) + + assert allclose(out1, out2) + + +@fundef +def sparse_shifted_stencil_reduce(inp): + def sum_(a, b): + return a + b + + # return deref(shift(V2V, 0)(lift(deref)(shift(0)(inp)))) + return reduce(sum_, 0)(shift(V2V)(lift(reduce(sum_, 0))(inp))) + + +@fendef +def sparse_shifted_fencil_reduce(inp, out): + closure( + domain(named_range(Vertex, 0, 9)), + sparse_shifted_stencil_reduce, + [out], + [inp], + ) + + +def test_shift_sparse_input_field(): + inp = np_as_located_field(Vertex, V2V)(v2v_arr) + out = np_as_located_field(Vertex)(np.zeros([9])) + + ref = [] + for row in v2v_arr: + elem_sum = 0 + for neigh in row: + elem_sum += sum(v2v_arr[neigh]) + ref.append(elem_sum) + + ref = np.asarray(ref) + + sparse_shifted_fencil_reduce( + inp, + out, + backend="double_roundtrip", + offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + ) + + assert allclose(np.asarray(out), ref) diff --git a/tests/iterator_tests/test_vertical_advection.py b/tests/iterator_tests/test_vertical_advection.py new file mode 100644 index 0000000000..0f649f278f --- /dev/null +++ b/tests/iterator_tests/test_vertical_advection.py @@ -0,0 +1,124 @@ +from iterator.builtins import * +from iterator.embedded import np_as_located_field +from iterator.runtime import * +import numpy as np +import pytest + + +@fundef +def tridiag_forward(state, a, b, c, d): + # not tracable + # if is_none(state): + # cp_k = deref(c) / deref(b) + # dp_k = deref(d) / deref(b) + # else: + # cp_km1, dp_km1 = state + # cp_k = deref(c) / (deref(b) - deref(a) * cp_km1) + # dp_k = (deref(d) - deref(a) * dp_km1) / (deref(b) - deref(a) * cp_km1) + # return make_tuple(cp_k, dp_k) + + # variant a + # return if_( + # is_none(state), + # make_tuple(deref(c) / deref(b), deref(d) / deref(b)), + # make_tuple( + # deref(c) / (deref(b) - deref(a) * nth(0, state)), + # (deref(d) - deref(a) * nth(1, state)) + # / (deref(b) - deref(a) * nth(0, state)), + # ), + # ) + + # variant b + def initial(): + return make_tuple(deref(c) / deref(b), deref(d) / deref(b)) + + def step(): + return make_tuple( + deref(c) / (deref(b) - deref(a) * nth(0, state)), + (deref(d) - deref(a) * nth(1, state)) / (deref(b) - deref(a) * nth(0, state)), + ) + + return if_(is_none(state), initial, step)() + + +@fundef +def tridiag_backward(x_kp1, cp, dp): + # if is_none(x_kp1): + # x_k = deref(dp) + # else: + # x_k = deref(dp) - deref(cp) * x_kp1 + # return x_k + return if_(is_none(x_kp1), deref(dp), deref(dp) - deref(cp) * x_kp1) + + +@fundef +def solve_tridiag(a, b, c, d): + tup = lift(scan(tridiag_forward, True, None))(a, b, c, d) + cp = nth(0, tup) + dp = nth(1, tup) + return scan(tridiag_backward, False, None)(cp, dp) + + +@pytest.fixture +def tridiag_reference(): + shape = (3, 7, 5) + rng = np.random.default_rng() + a = rng.normal(size=shape) + b = rng.normal(size=shape) * 2 + c = rng.normal(size=shape) + d = rng.normal(size=shape) + + matrices = 
np.zeros(shape + shape[-1:]) + i = np.arange(shape[2]) + matrices[:, :, i[1:], i[:-1]] = a[:, :, 1:] + matrices[:, :, i, i] = b + matrices[:, :, i[:-1], i[1:]] = c[:, :, :-1] + x = np.linalg.solve(matrices, d) + return a, b, c, d, x + + +IDim = CartesianAxis("IDim") +JDim = CartesianAxis("JDim") +KDim = CartesianAxis("KDim") + + +@fendef +def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x): + closure( + domain( + named_range(IDim, 0, i_size), + named_range(JDim, 0, j_size), + named_range(KDim, 0, k_size), + ), + solve_tridiag, + [x], + [a, b, c, d], + ) + + +def test_tridiag(tridiag_reference): + a, b, c, d, x = tridiag_reference + shape = a.shape + as_3d_field = np_as_located_field(IDim, JDim, KDim) + a_s = as_3d_field(a) + b_s = as_3d_field(b) + c_s = as_3d_field(c) + d_s = as_3d_field(d) + x_s = as_3d_field(np.zeros_like(x)) + + fen_solve_tridiag( + shape[0], + shape[1], + shape[2], + a_s, + b_s, + c_s, + d_s, + x_s, + offset_provider={}, + column_axis=KDim, + backend="double_roundtrip", + # debug=True, + ) + + assert np.allclose(x, np.asarray(x_s)) From d9046e4f7180948ca47a46bf3427ec10230cfe76 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 6 Aug 2021 09:36:39 +0200 Subject: [PATCH 2/6] add more doc --- src/iterator/ARCHITECTURE.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/iterator/ARCHITECTURE.md b/src/iterator/ARCHITECTURE.md index decd642db9..28daa3728e 100644 --- a/src/iterator/ARCHITECTURE.md +++ b/src/iterator/ARCHITECTURE.md @@ -58,3 +58,10 @@ Generates from the IR an aquivalent Python iterator view program which is then e ### Double roundtrip Generates the Python iterator view program, traces it again, generates again and executes. Ensures that the generated Python code can still be traced. While the original program might be different from the generated program (e.g. `+` will be converted to `plus()` builtin). The programs from the embedded and double roundtrip backends should be identical. + +## Adding a new builtin + +Currently there are 4 places where a new builtin needs to be added +- `builtin.py`: for dispatching to an actual implementation +- `embedded.py` and `tracing.py`: for the respective implementation +- `ir.py`: we check for consistent use of symbols, therefore if a `FunCall` to the new builtin is used, it needs to be available in the symbol table. 
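The four places listed above can be seen in action in PATCH 5 below, which adds the `or_` and `eq` builtins. As a quick reference, the same recipe applied to a hypothetical `minimum` builtin would look roughly like the sketch below; the name `minimum` and its semantics are illustrative only and are not part of this patch series. Each snippet goes into the named module and relies on the helpers already defined there (`builtin_dispatch`, `EMBEDDED`, `TRACING`, `_f`).

```python
# builtins.py: dispatch stub, raises until a backend implementation is registered
@builtin_dispatch
def minimum(*args):
    raise BackendNotSelectedError()


# embedded.py: concrete implementation used in direct (embedded) execution
@minimum.register(EMBEDDED)
def minimum(first, second):
    return first if first < second else second


# tracing.py: instead of computing a value, emit an IR call node
@iterator.builtins.minimum.register(TRACING)
def minimum(*args):
    return _f("minimum", *args)


# ir.py: add "minimum" to the list of builtin symbols of Program, so that
# symbol validation accepts FunCall(fun=SymRef(id="minimum"), args=[...])
```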
From c26662c82a4ea2c0858160b8f10e0c2ad3f589ec Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 15 Oct 2021 14:39:34 +0200 Subject: [PATCH 3/6] execute shift once complete --- .gitignore | 1 + src/iterator/embedded.py | 49 +++++++++------- tests/iterator_tests/test_toy_connectivity.py | 57 +++++++++++++++++-- 3 files changed, 79 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index d35386ed39..d6289092fe 100644 --- a/.gitignore +++ b/.gitignore @@ -107,6 +107,7 @@ venv.bak/ /site # DaCe +.dacecache *.sdfg .dace.conf diff --git a/src/iterator/embedded.py b/src/iterator/embedded.py index 2344f5f5ba..3201426b8f 100644 --- a/src/iterator/embedded.py +++ b/src/iterator/embedded.py @@ -268,11 +268,11 @@ def execute_shift(pos, tag, index, *, offset_provider): # shift(tag)(inp) -> incomplete shift # shift(index)(shift(tag)(inp)) -> full shift # Therefore the following transformation holds -# shift(e2c,0)(shift(v2c,2)(cell_field)) -# = shift(0)(shift(e2c)(shift(2)(shift(v2c)(cell_field)))) -# = shift(v2c, 2, e2c, 0)(cell_field) -# = shift(v2c,e2c,2,0)(cell_field) <-- v2c,e2c twice incomplete shift -# = shift(2,0)(shift(v2c,e2c)(cell_field)) +# shift(e2v,0)(shift(c2e,2)(cell_field)) +# = shift(0)(shift(e2v)(shift(2)(shift(c2e)(cell_field)))) +# = shift(c2e, 2, e2v, 0)(cell_field) +# = shift(c2e,e2v,2,0)(cell_field) <-- v2c,e2c twice incomplete shift +# = shift(2,0)(shift(c2e,e2v)(cell_field)) # for implementations it means everytime we have an index, we can "execute" a concrete shift def group_offsets(*offsets): tag_stack = [] @@ -286,15 +286,16 @@ def group_offsets(*offsets): else: tag_stack.append(offset) else: - assert not tag_stack - index_stack.append(offset) + if tag_stack: + tag = tag_stack.pop(0) + complete_offsets.append((tag, offset)) + # assert not tag_stack + else: + index_stack.append(offset) return complete_offsets, tag_stack -def shift_position(pos, *offsets, offset_provider): - complete_offsets, open_offsets = group_offsets(*offsets) - # assert not open_offsets # TODO enable this, check failing test and make everything saver - +def shift_position(pos, *complete_offsets, offset_provider): new_pos = pos.copy() for tag, index in complete_offsets: new_pos = execute_shift(new_pos, tag, index, offset_provider=offset_provider) @@ -308,33 +309,37 @@ def get_open_offsets(*offsets): class MDIterator: - def __init__(self, field, pos, *, offsets=[], offset_provider, column_axis=None) -> None: + def __init__( + self, field, pos, *, incomplete_offsets=[], offset_provider, column_axis=None + ) -> None: self.field = field self.pos = pos - self.offsets = offsets + self.incomplete_offsets = incomplete_offsets self.offset_provider = offset_provider self.column_axis = column_axis def shift(self, *offsets): + complete_offsets, open_offsets = group_offsets(*self.incomplete_offsets, *offsets) return MDIterator( self.field, - self.pos, - offsets=[*offsets, *self.offsets], + shift_position(self.pos, *complete_offsets, offset_provider=self.offset_provider), + incomplete_offsets=open_offsets, offset_provider=self.offset_provider, column_axis=self.column_axis, ) def max_neighbors(self): - open_offsets = get_open_offsets(*self.offsets) - assert open_offsets - assert isinstance(self.offset_provider[open_offsets[0].value], NeighborTableOffsetProvider) - return self.offset_provider[open_offsets[0].value].max_neighbors + assert self.incomplete_offsets + assert isinstance( + self.offset_provider[self.incomplete_offsets[0].value], NeighborTableOffsetProvider + ) + return 
self.offset_provider[self.incomplete_offsets[0].value].max_neighbors def is_none(self): - return shift_position(self.pos, *self.offsets, offset_provider=self.offset_provider) is None + return self.pos is None def deref(self): - shifted_pos = shift_position(self.pos, *self.offsets, offset_provider=self.offset_provider) + shifted_pos = self.pos.copy() if not all(axis in shifted_pos.keys() for axis in self.field.axises): raise IndexError("Iterator position doesn't point to valid location for its field.") @@ -362,7 +367,7 @@ def make_in_iterator(inp, pos, offset_provider, *, column_axis): return MDIterator( inp, new_pos, - offsets=[*sparse_dimensions], + incomplete_offsets=[*sparse_dimensions], offset_provider=offset_provider, column_axis=column_axis, ) diff --git a/tests/iterator_tests/test_toy_connectivity.py b/tests/iterator_tests/test_toy_connectivity.py index fb4e8e3917..453ff17919 100644 --- a/tests/iterator_tests/test_toy_connectivity.py +++ b/tests/iterator_tests/test_toy_connectivity.py @@ -12,17 +12,32 @@ Vertex = CartesianAxis("Vertex") Edge = CartesianAxis("Edge") +Cell = CartesianAxis("Cell") -# 3x3 periodic edges +# 3x3 periodic edges cells # 0 - 1 - 2 - 0 1 2 -# | | | 9 10 11 +# | | | 9 10 11 0 1 2 # 3 - 4 - 5 - 3 4 5 -# | | | 12 13 14 +# | | | 12 13 14 3 4 5 # 6 - 7 - 8 - 6 7 8 -# | | | 15 16 17 +# | | | 15 16 17 6 7 8 +c2e_arr = np.array( + [ + [0, 10, 3, 9], # 0 + [1, 11, 4, 10], + [2, 9, 5, 11], + [3, 13, 6, 12], # 3 + [4, 14, 7, 13], + [5, 12, 8, 14], + [6, 16, 0, 15], # 6 + [7, 17, 1, 16], + [8, 15, 2, 17], + ] +) + v2v_arr = np.array( [ [1, 3, 2, 6], @@ -78,6 +93,7 @@ V2E = offset("V2E") E2V = offset("E2V") +C2E = offset("C2E") @fundef @@ -135,6 +151,35 @@ def test_sum_edges_to_vertices_reduce(): assert allclose(out, ref) +@fundef +def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices): + return deref(shift(E2V, 0)(shift(C2E, 0)(in_vertices))) + + +@fendef( + offset_provider={ + "E2V": NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), + "C2E": NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4), + } +) +def first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(in_vertices, out_cells): + closure( + domain(named_range(Cell, 0, 9)), + first_vertex_neigh_of_first_edge_neigh_of_cells, + [out_cells], + [in_vertices], + ) + + +def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(): + inp = index_field(Vertex) + out = np_as_located_field(Cell)(np.zeros([9])) + ref = np.asarray(list(v2e_arr[c[0]][0] for c in c2e_arr)) + + first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(inp, out, backend="double_roundtrip") + assert allclose(out, ref) + + @fundef def sparse_stencil(inp): return reduce(lambda a, b: a + b, 0)(inp) @@ -247,12 +292,12 @@ def test_shift_sparse_input_field(): @fundef def shift_shift_stencil2(inp): - return deref(shift(V2E, 3)(shift(E2V, 1)(inp))) + return deref(shift(E2V, 1)(shift(V2E, 3)(inp))) @fundef def shift_sparse_stencil2(inp): - return deref(shift(1, 3)(shift(V2E)(inp))) + return deref(shift(3, 1)(shift(V2E)(inp))) @fendef From aac8839eca9f43430e7c91aae16ded563ce04035 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 18 Oct 2021 11:30:24 +0200 Subject: [PATCH 4/6] cleanup lifting and fix sparse example --- src/iterator/embedded.py | 47 +++---------------- tests/iterator_tests/test_toy_connectivity.py | 2 +- 2 files changed, 8 insertions(+), 41 deletions(-) diff --git a/src/iterator/embedded.py b/src/iterator/embedded.py index 3201426b8f..87fe81d8e3 100644 --- a/src/iterator/embedded.py +++ b/src/iterator/embedded.py @@ -70,7 
+70,7 @@ def __getitem__(self, index): return wrap_iterator(offsets=self.offsets, elem=index) def shift(self, *offsets): - return wrap_iterator(offsets=[*offsets, *self.offsets], elem=self.elem) + return wrap_iterator(offsets=[*self.offsets, *offsets], elem=self.elem) def max_neighbors(self): # TODO cleanup, test edge cases @@ -83,32 +83,7 @@ def max_neighbors(self): return args[0].offset_provider[open_offsets[0].value].max_neighbors def deref(self): - class DelayedIterator: - def __init__(self, wrapped_iterator, lifted_offsets, *, offsets=[]) -> None: - self.wrapped_iterator = wrapped_iterator - self.lifted_offsets = lifted_offsets - self.offsets = offsets - - def is_none(self): - shifted = self.wrapped_iterator.shift(*self.lifted_offsets, *self.offsets) - return shifted.is_none() - - def max_neighbors(self): - shifted = self.wrapped_iterator.shift(*self.lifted_offsets, *self.offsets) - return shifted.max_neighbors() - - def shift(self, *offsets): - return DelayedIterator( - self.wrapped_iterator, - self.lifted_offsets, - offsets=[*offsets, *self.offsets], - ) - - def deref(self): - shifted = self.wrapped_iterator.shift(*self.lifted_offsets, *self.offsets) - return shifted.deref() - - shifted_args = tuple(map(lambda arg: DelayedIterator(arg, self.offsets), args)) + shifted_args = tuple(map(lambda arg: arg.shift(*self.offsets), args)) if any(shifted_arg.is_none() for shifted_arg in shifted_args): return None @@ -276,22 +251,14 @@ def execute_shift(pos, tag, index, *, offset_provider): # for implementations it means everytime we have an index, we can "execute" a concrete shift def group_offsets(*offsets): tag_stack = [] - index_stack = [] complete_offsets = [] for offset in offsets: if not isinstance(offset, int): - if index_stack: - index = index_stack.pop(0) - complete_offsets.append((offset, index)) - else: - tag_stack.append(offset) + tag_stack.append(offset) else: - if tag_stack: - tag = tag_stack.pop(0) - complete_offsets.append((tag, offset)) - # assert not tag_stack - else: - index_stack.append(offset) + assert tag_stack + tag = tag_stack.pop(0) + complete_offsets.append((tag, offset)) return complete_offsets, tag_stack @@ -468,7 +435,7 @@ def index_field(axis): @iterator.builtins.shift.register(EMBEDDED) def shift(*offsets): def impl(iter): - return iter.shift(*reversed(offsets)) + return iter.shift(*offsets) return impl diff --git a/tests/iterator_tests/test_toy_connectivity.py b/tests/iterator_tests/test_toy_connectivity.py index 453ff17919..226eac1eb7 100644 --- a/tests/iterator_tests/test_toy_connectivity.py +++ b/tests/iterator_tests/test_toy_connectivity.py @@ -297,7 +297,7 @@ def shift_shift_stencil2(inp): @fundef def shift_sparse_stencil2(inp): - return deref(shift(3, 1)(shift(V2E)(inp))) + return deref(shift(1, 3)(shift(V2E)(inp))) @fendef From 7a4bb1f9d148d53186259d5217fd2e65f0087974 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 20 Oct 2021 22:22:26 +0200 Subject: [PATCH 5/6] fvm sign field --- src/iterator/builtins.py | 16 +- src/iterator/embedded.py | 12 ++ src/iterator/ir.py | 2 + src/iterator/tracing.py | 14 ++ tests/iterator_tests/fvm_nabla_setup.py | 31 ++-- tests/iterator_tests/test_fvm_nabla.py | 212 +++++++++--------------- 6 files changed, 140 insertions(+), 147 deletions(-) diff --git a/src/iterator/builtins.py b/src/iterator/builtins.py index 5e2ddecef3..5dfd0a18fa 100644 --- a/src/iterator/builtins.py +++ b/src/iterator/builtins.py @@ -11,10 +11,12 @@ "named_range", "compose", "if_", + "or_", "minus", "plus", "mul", "div", + "eq", "greater", 
"make_tuple", "nth", @@ -78,6 +80,11 @@ def if_(*args): raise BackendNotSelectedError() +@builtin_dispatch +def or_(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def minus(*args): raise BackendNotSelectedError() @@ -98,6 +105,11 @@ def div(*args): raise BackendNotSelectedError() +@builtin_dispatch +def eq(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def greater(*args): raise BackendNotSelectedError() @@ -105,9 +117,9 @@ def greater(*args): @builtin_dispatch def make_tuple(*args): - raise BackendNotSelectedError + raise BackendNotSelectedError() @builtin_dispatch def nth(*args): - raise BackendNotSelectedError + raise BackendNotSelectedError() diff --git a/src/iterator/embedded.py b/src/iterator/embedded.py index 87fe81d8e3..679d065b94 100644 --- a/src/iterator/embedded.py +++ b/src/iterator/embedded.py @@ -13,11 +13,13 @@ domain, named_range, if_, + or_, minus, plus, mul, div, greater, + eq, nth, make_tuple, ) @@ -47,6 +49,11 @@ def if_(cond, t, f): return t if cond else f +@or_.register(EMBEDDED) +def or_(a, b): + return a or b + + @nth.register(EMBEDDED) def nth(i, tup): return tup[i] @@ -195,6 +202,11 @@ def div(first, second): return first / second +@eq.register(EMBEDDED) +def eq(first, second): + return first == second + + @greater.register(EMBEDDED) def greater(first, second): return first > second diff --git a/src/iterator/ir.py b/src/iterator/ir.py index 856becdea7..f2214a4dd1 100644 --- a/src/iterator/ir.py +++ b/src/iterator/ir.py @@ -108,9 +108,11 @@ class Program(Node, SymbolTableTrait): "minus", "mul", "div", + "eq", "greater", "less", "if_", + "or_", ] ) _validate_symbol_refs = validate_symbol_refs() diff --git a/src/iterator/tracing.py b/src/iterator/tracing.py index 74e4e1d4c6..cca543ad87 100644 --- a/src/iterator/tracing.py +++ b/src/iterator/tracing.py @@ -58,6 +58,10 @@ def __truediv__(self, other): def __sub__(self, other): return FunCall(fun=SymRef(id="minus"), args=[self, make_node(other)]) + @monkeypatch_method(Expr) + def __eq__(self, other): + return FunCall(fun=SymRef(id="eq"), args=[self, make_node(other)]) + @monkeypatch_method(Expr) def __gt__(self, other): return FunCall(fun=SymRef(id="greater"), args=[self, make_node(other)]) @@ -154,6 +158,11 @@ def if_(*args): return _f("if_", *args) +@iterator.builtins.or_.register(TRACING) +def or_(*args): + return _f("or_", *args) + + # shift promotes its arguments to literals, therefore special @iterator.builtins.shift.register(TRACING) def shift(*offsets): @@ -181,6 +190,11 @@ def div(*args): return _f("div", *args) +@iterator.builtins.eq.register(TRACING) +def eq(*args): + return _f("eq", *args) + + @iterator.builtins.greater.register(TRACING) def greater(*args): return _f("greater", *args) diff --git a/tests/iterator_tests/fvm_nabla_setup.py b/tests/iterator_tests/fvm_nabla_setup.py index c656393513..b8f01d6cee 100644 --- a/tests/iterator_tests/fvm_nabla_setup.py +++ b/tests/iterator_tests/fvm_nabla_setup.py @@ -26,9 +26,7 @@ def assert_close(expected, actual): - assert math.isclose(expected, actual), "expected={}, actual={}".format( - expected, actual - ) + assert math.isclose(expected, actual), "expected={}, actual={}".format(expected, actual) class nabla_setup: @@ -49,10 +47,7 @@ def __init__(self, *, grid=StructuredGrid("O32"), config=_default_config()): build_median_dual_mesh(mesh) edges_per_node = max( - [ - mesh.nodes.edge_connectivity.cols(node) - for node in range(0, fs_nodes.size) - ] + [mesh.nodes.edge_connectivity.cols(node) for node in range(0, fs_nodes.size)] ) 
self.mesh = mesh @@ -76,14 +71,24 @@ def nodes_size(self): def edges_size(self): return self.fs_edges.size + @staticmethod + def _is_pole_edge(e, edge_flags): + return Topology.check(edge_flags[e], Topology.POLE) + + @property + def is_pole_edge_field(self): + edge_flags = np.array(self.mesh.edges.flags()) + + pole_edge_field = np.zeros((self.edges_size,)) + for e in range(self.edges_size): + pole_edge_field[e] = self._is_pole_edge(e, edge_flags) + return edge_flags + @property def sign_field(self): node2edge_sign = np.zeros((self.nodes_size, self.edges_per_node)) edge_flags = np.array(self.mesh.edges.flags()) - def is_pole_edge(e): - return Topology.check(edge_flags[e], Topology.POLE) - for jnode in range(0, self.nodes_size): node_edge_con = self.mesh.nodes.edge_connectivity edge_node_con = self.mesh.edges.node_connectivity @@ -94,7 +99,7 @@ def is_pole_edge(e): node2edge_sign[jnode, jedge] = 1.0 else: node2edge_sign[jnode, jedge] = -1.0 - if is_pole_edge(iedge): + if self._is_pole_edge(iedge, edge_flags): node2edge_sign[jnode, jedge] = 1.0 return node2edge_sign @@ -182,9 +187,7 @@ def input_field(self): for jnode in range(0, self.nodes_size): for i in range(0, 2): rcoords[jnode, klevel, i] = rcoords_deg[jnode, i] * deg2rad - rlonlatcr[jnode, klevel, i] = rcoords[ - jnode, klevel, i - ] # This is not my pattern! + rlonlatcr[jnode, klevel, i] = rcoords[jnode, klevel, i] # This is not my pattern! rcosa[jnode, klevel] = math.cos(rlonlatcr[jnode, klevel, MYY]) rsina[jnode, klevel] = math.sin(rlonlatcr[jnode, klevel, MYY]) for jnode in range(0, self.nodes_size): diff --git a/tests/iterator_tests/test_fvm_nabla.py b/tests/iterator_tests/test_fvm_nabla.py index 0e35790c59..0c18b6b4d8 100644 --- a/tests/iterator_tests/test_fvm_nabla.py +++ b/tests/iterator_tests/test_fvm_nabla.py @@ -11,8 +11,10 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import itertools + from iterator.atlas_utils import AtlasTable -from iterator.embedded import NeighborTableOffsetProvider, np_as_located_field +from iterator.embedded import NeighborTableOffsetProvider, index_field, np_as_located_field from iterator.runtime import * from iterator.builtins import * from iterator import library @@ -153,133 +155,81 @@ def test_nabla(): assert_close(3.3540113705465301e-003, max(pnabla_MYY)) -# @stencil -# def sign(e2v, v2e, node_indices, is_pole_edge): -# node_indices_of_neighbor_edge = node_indices[e2v[v2e]] -# pole_flag_of_neighbor_edges = is_pole_edge[v2e] -# sign_field = if_( -# pole_flag_of_neighbor_edges -# | (broadcast(V2E)(node_indices) == node_indices_of_neighbor_edge[E2V(0)]), -# constant_field(Vertex, V2E)(1.0), -# constant_field(Vertex, V2E)(-1.0), -# ) -# return sign_field - - -# @stencil -# def compute_pnabla_sign(e2v, v2e, pp, S_M, node_indices, is_pole_edge, vol): -# zavgS = compute_zavgS(pp[e2v], S_M)[v2e] -# pnabla_M = sum_reduce(V2E)(zavgS * sign(e2v, v2e, node_indices, is_pole_edge)) - -# return pnabla_M / vol - - -# def nabla_sign( -# e2v, -# v2e, -# pp, -# S_MXX, -# S_MYY, -# node_indices, -# is_pole_edge, -# vol, -# ): -# return ( -# compute_pnabla_sign(e2v, v2e, pp, S_MXX, node_indices, is_pole_edge, vol), -# compute_pnabla_sign(e2v, v2e, pp, S_MYY, node_indices, is_pole_edge, vol), -# ) - - -# def test_nabla_from_sign_stencil(): -# setup = nabla_setup() - -# pp = array_as_field(Vertex)(setup.input_field) -# S_MXX, S_MYY = tuple(map(array_as_field(Edge), setup.S_fields)) -# vol = array_as_field(Vertex)(setup.vol_field) - -# edge_flags = np.array(setup.mesh.edges.flags()) -# is_pole_edge = 
array_as_field(Edge)( -# np.array([Topology.check(flag, Topology.POLE) for flag in edge_flags]) -# ) - -# node_index_field = index_field(Vertex, range(setup.nodes_size)) - -# e2v = make_sparse_index_field_from_atlas_connectivity( -# setup.edges2node_connectivity, Edge, E2V, Vertex -# ) -# v2e = make_sparse_index_field_from_atlas_connectivity( -# setup.nodes2edge_connectivity, Vertex, V2E, Edge -# ) - -# pnabla_MXX = np.zeros((setup.nodes_size)) -# pnabla_MYY = np.zeros((setup.nodes_size)) - -# print(f"nodes: {setup.nodes_size}") -# print(f"edges: {setup.edges_size}") - -# pnabla_MXX[:], pnabla_MYY[:] = nabla_sign( -# e2v, v2e, pp, S_MXX, S_MYY, node_index_field, is_pole_edge, vol -# ) - -# assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) -# assert_close(3.5455427772565435e-003, max(pnabla_MXX)) -# assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) -# assert_close(3.3540113705465301e-003, max(pnabla_MYY)) - - -# @stencil -# def compute_pnabla_on_nodes(v2e, v2e_e2v_pp, S_M, sign, vol): -# zavgS = S_M[v2e] * 0.5 * (v2e_e2v_pp[E2V(0)] + v2e_e2v_pp[E2V(1)]) -# pnabla_M = sum_reduce(V2E)(zavgS * sign) - -# return pnabla_M / vol - - -# def nabla_on_nodes( -# v2e2v, -# v2e, -# pp, -# S_MXX, -# S_MYY, -# sign, -# vol, -# ): -# v2e_e2v_pp = materialize(pp[v2e2v]) -# return ( -# compute_pnabla_on_nodes(v2e, v2e_e2v_pp, S_MXX, sign, vol), -# compute_pnabla_on_nodes(v2e, v2e_e2v_pp, S_MYY, sign, vol), -# ) - - -# def test_nabla_on_nodes(): -# setup = nabla_setup() - -# sign_acc = array_as_field(Vertex, V2E)(setup.sign_field) -# pp = array_as_field(Vertex)(setup.input_field) -# S_MXX, S_MYY = tuple(map(array_as_field(Edge), setup.S_fields)) -# vol = array_as_field(Vertex)(setup.vol_field) - -# e2v = make_sparse_index_field_from_atlas_connectivity( -# setup.edges2node_connectivity, Edge, E2V, Vertex -# ) -# v2e = make_sparse_index_field_from_atlas_connectivity( -# setup.nodes2edge_connectivity, Vertex, V2E, Edge -# ) -# v2e2v = e2v[ -# v2e -# ] # TODO materialize is broken because I don't preserver optional if materialized as np array - -# pnabla_MXX = np.zeros((setup.nodes_size)) -# pnabla_MYY = np.zeros((setup.nodes_size)) - -# print(f"nodes: {setup.nodes_size}") -# print(f"edges: {setup.edges_size}") - -# pnabla_MXX[:], pnabla_MYY[:] = nabla_on_nodes( -# v2e2v, v2e, pp, S_MXX, S_MYY, sign_acc, vol -# ) - -# assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) -# assert_close(3.5455427772565435e-003, max(pnabla_MXX)) -# assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) -# assert_close(3.3540113705465301e-003, max(pnabla_MYY)) +@fundef +def sign(node_indices, is_pole_edge): + node_index = deref(node_indices) + + @fundef + def sign_impl(node_index): + def impl2(node_indices, is_pole_edge): + return if_( + or_(deref(is_pole_edge), node_index == deref(shift(E2V, 0)(node_indices))), + 1.0, + -1.0, + ) + + return impl2 + + return shift(V2E)(lift(sign_impl(node_index))(node_indices, is_pole_edge)) + + +@fundef +def compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge): + zavgS = lift(compute_zavgS)(pp, S_M) + pnabla_M = library.dot(shift(V2E)(zavgS), sign(node_index, is_pole_edge)) + + return pnabla_M / deref(vol) + + +@fendef +def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge): + # TODO replace by single stencil which returns tuple + closure( + domain(named_range(Vertex, 0, n_nodes)), + compute_pnabla_sign, + [out_MXX], + [pp, S_MXX, vol, node_index, is_pole_edge], + ) + closure( + domain(named_range(Vertex, 0, n_nodes)), + 
compute_pnabla_sign, + [out_MYY], + [pp, S_MYY, vol, node_index, is_pole_edge], + ) + + +def test_nabla_sign(): + setup = nabla_setup() + + # sign = np_as_located_field(Vertex, V2E)(setup.sign_field) + is_pole_edge = np_as_located_field(Edge)(setup.is_pole_edge_field) + pp = np_as_located_field(Vertex)(setup.input_field) + S_MXX, S_MYY = tuple(map(np_as_located_field(Edge), setup.S_fields)) + vol = np_as_located_field(Vertex)(setup.vol_field) + + pnabla_MXX = np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + pnabla_MYY = np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + + e2v = NeighborTableOffsetProvider(AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2) + v2e = NeighborTableOffsetProvider(AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7) + + nabla_sign( + setup.nodes_size, + pnabla_MXX, + pnabla_MYY, + pp, + S_MXX, + S_MYY, + vol, + index_field(Vertex), + is_pole_edge, + # backend="embedded", + backend="double_roundtrip", + offset_provider={"E2V": e2v, "V2E": v2e}, + # debug=True, + ) + + assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) + assert_close(3.5455427772565435e-003, max(pnabla_MXX)) + assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) + assert_close(3.3540113705465301e-003, max(pnabla_MYY)) From 3e4c765859852b533f3675301e2d1e3c7d2e2924 Mon Sep 17 00:00:00 2001 From: Felix Thaler Date: Mon, 25 Oct 2021 12:20:15 +0000 Subject: [PATCH 6/6] Iterator Transforms (#21) - Normalize IR (popup temporaries) - various cleanups --- setup.cfg | 2 +- src/iterator/__init__.py | 7 +- src/iterator/atlas_utils.py | 2 +- src/iterator/backend_executor.py | 5 +- src/iterator/backends/__init__.py | 2 +- src/iterator/backends/cpptoy.py | 18 +- src/iterator/backends/embedded.py | 75 ++++++-- src/iterator/backends/lisp.py | 27 ++- src/iterator/builtins.py | 5 + src/iterator/dispatcher.py | 13 +- src/iterator/embedded.py | 173 +++++++++++------- src/iterator/ir.py | 11 +- src/iterator/library.py | 5 +- src/iterator/runtime.py | 7 +- src/iterator/tracing.py | 30 +-- src/iterator/transforms/__init__.py | 4 + src/iterator/transforms/collect_shifts.py | 31 ++++ src/iterator/transforms/common.py | 21 +++ src/iterator/transforms/global_tmps.py | 121 ++++++++++++ src/iterator/transforms/inline_fundefs.py | 35 ++++ src/iterator/transforms/inline_lambdas.py | 32 ++++ src/iterator/transforms/inline_lifts.py | 37 ++++ src/iterator/transforms/normalize_shifts.py | 25 +++ src/iterator/transforms/popup_tmps.py | 60 ++++++ src/iterator/transforms/remap_symbols.py | 48 +++++ src/iterator/util/sym_validation.py | 16 +- tests/iterator_tests/conftest.py | 20 ++ tests/iterator_tests/fvm_nabla_setup.py | 18 +- tests/iterator_tests/hdiff_reference.py | 2 +- tests/iterator_tests/test_anton_toy.py | 33 ++-- .../test_cartesian_offset_provider.py | 5 +- tests/iterator_tests/test_column_stencil.py | 61 ++++-- tests/iterator_tests/test_fvm_nabla.py | 68 ++++--- tests/iterator_tests/test_hdiff.py | 21 +-- .../test_horizontal_indirection.py | 6 +- tests/iterator_tests/test_popup_tmps.py | 156 ++++++++++++++++ tests/iterator_tests/test_toy_connectivity.py | 76 ++++---- tests/iterator_tests/test_trivial.py | 53 ++++++ .../iterator_tests/test_vertical_advection.py | 17 +- 39 files changed, 1082 insertions(+), 266 deletions(-) create mode 100644 src/iterator/transforms/__init__.py create mode 100644 src/iterator/transforms/collect_shifts.py create mode 100644 src/iterator/transforms/common.py create mode 100644 src/iterator/transforms/global_tmps.py create mode 100644 
src/iterator/transforms/inline_fundefs.py create mode 100644 src/iterator/transforms/inline_lambdas.py create mode 100644 src/iterator/transforms/inline_lifts.py create mode 100644 src/iterator/transforms/normalize_shifts.py create mode 100644 src/iterator/transforms/popup_tmps.py create mode 100644 src/iterator/transforms/remap_symbols.py create mode 100644 tests/iterator_tests/conftest.py create mode 100644 tests/iterator_tests/test_popup_tmps.py create mode 100644 tests/iterator_tests/test_trivial.py diff --git a/setup.cfg b/setup.cfg index 0b84c6ce00..37f41ecd0a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -183,7 +183,7 @@ lines_after_imports = 2 default_section = THIRDPARTY sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER known_first_party = eve,gtc,gt4py,__externals__,__gtscript__ -known_third_party = attr,black,boltons,cached_property,click,dace,dawn4py,devtools,factory,hypothesis,jinja2,mako,networkx,numpy,packaging,pkg_resources,pybind11,pydantic,pytest,pytest_factoryboy,setuptools,tabulate,tests,typing_extensions,typing_inspect,xxhash +known_third_party = atlas4py,attr,black,boltons,cached_property,click,dace,dawn4py,devtools,factory,hypothesis,jinja2,mako,networkx,numpy,packaging,pkg_resources,pybind11,pydantic,pytest,pytest_factoryboy,setuptools,tabulate,tests,typing_extensions,typing_inspect,xxhash #-- mypy -- [mypy] diff --git a/src/iterator/__init__.py b/src/iterator/__init__.py index 3d53598001..1f67e0c238 100644 --- a/src/iterator/__init__.py +++ b/src/iterator/__init__.py @@ -1,10 +1,9 @@ from typing import Optional, Union -from . import tracing -from .builtins import * -from .runtime import * +from . import builtins, runtime, tracing -__all__ = ["builtins", "runtime"] + +__all__ = ["builtins", "runtime", "tracing"] from packaging.version import LegacyVersion, Version, parse from pkg_resources import DistributionNotFound, get_distribution diff --git a/src/iterator/atlas_utils.py b/src/iterator/atlas_utils.py index dc6907186a..153abb6fa6 100644 --- a/src/iterator/atlas_utils.py +++ b/src/iterator/atlas_utils.py @@ -30,4 +30,4 @@ def __getitem__(self, indices): if neigh_index < 2: return self.atlas_connectivity[primary_index, neigh_index] else: - assert False + raise AssertionError() diff --git a/src/iterator/backend_executor.py b/src/iterator/backend_executor.py index 2231ef3d35..0f66174c8b 100644 --- a/src/iterator/backend_executor.py +++ b/src/iterator/backend_executor.py @@ -1,7 +1,8 @@ -from iterator.ir import Program -from iterator.backends import backend from devtools import debug +from iterator.backends import backend +from iterator.ir import Program + def execute_program(prog: Program, *args, **kwargs): assert "backend" in kwargs diff --git a/src/iterator/backends/__init__.py b/src/iterator/backends/__init__.py index ac69509c36..b926168790 100644 --- a/src/iterator/backends/__init__.py +++ b/src/iterator/backends/__init__.py @@ -1 +1 @@ -from . import cpptoy, lisp, embedded, double_roundtrip +from . 
import cpptoy, double_roundtrip, embedded, lisp diff --git a/src/iterator/backends/cpptoy.py b/src/iterator/backends/cpptoy.py index 8f81eb0f48..cb7e6bbd05 100644 --- a/src/iterator/backends/cpptoy.py +++ b/src/iterator/backends/cpptoy.py @@ -1,8 +1,11 @@ from typing import Any + from eve import codegen -from eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako -from iterator.ir import OffsetLiteral +from eve.codegen import FormatTemplate as as_fmt +from eve.codegen import MakoTemplate as as_mako from iterator.backends import backend +from iterator.ir import OffsetLiteral +from iterator.transforms import apply_common_transforms class ToyCpp(codegen.TemplatedGenerator): @@ -41,9 +44,16 @@ def visit_OffsetLiteral(self, node: OffsetLiteral, **kwargs): @classmethod def apply(cls, root, **kwargs: Any) -> str: - generated_code = super().apply(root, **kwargs) + transformed = apply_common_transforms( + root, + use_tmps=kwargs.get("use_tmps", False), + offset_provider=kwargs.get("offset_provider", None), + ) + generated_code = super().apply(transformed, **kwargs) formatted_code = codegen.format_source("cpp", generated_code, style="LLVM") return formatted_code -backend.register_backend("cpptoy", lambda prog, *args, **kwargs: print(ToyCpp.apply(prog))) +backend.register_backend( + "cpptoy", lambda prog, *args, **kwargs: print(ToyCpp.apply(prog, **kwargs)) +) diff --git a/src/iterator/backends/embedded.py b/src/iterator/backends/embedded.py index a495f556c8..718e5aa9af 100644 --- a/src/iterator/backends/embedded.py +++ b/src/iterator/backends/embedded.py @@ -1,11 +1,14 @@ +import importlib.util +import tempfile + +import iterator from eve import codegen -from eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako +from eve.codegen import FormatTemplate as as_fmt +from eve.codegen import MakoTemplate as as_mako from eve.concepts import Node -from iterator.ir import AxisLiteral, NoneLiteral, OffsetLiteral from iterator.backends import backend -import tempfile -import importlib.util -import iterator +from iterator.ir import AxisLiteral, FencilDefinition, OffsetLiteral +from iterator.transforms import apply_common_transforms class EmbeddedDSL(codegen.TemplatedGenerator): @@ -19,7 +22,7 @@ class EmbeddedDSL(codegen.TemplatedGenerator): AxisLiteral = as_fmt("{value}") StringLiteral = as_fmt("{value}") FunCall = as_fmt("{fun}({','.join(args)})") - Lambda = as_mako("lambda ${','.join(params)}: ${expr}") + Lambda = as_mako("(lambda ${','.join(params)}: ${expr})") StencilClosure = as_mako( "closure(${domain}, ${stencil}, [${','.join(outputs)}], [${','.join(inputs)}])" ) @@ -43,13 +46,58 @@ def ${id}(${','.join(params)}): ) +# TODO this wrapper should be replaced by an extension of the IR +class WrapperGenerator(EmbeddedDSL): + def visit_FencilDefinition(self, node: FencilDefinition, *, tmps): + params = self.visit(node.params) + non_tmp_params = [param for param in params if param not in tmps] + + body = [] + for tmp, domain in tmps.items(): + axis_literals = [named_range.args[0].value for named_range in domain.args] + origin = ( + "{" + + ", ".join( + f"{named_range.args[0].value}: -{self.visit(named_range.args[1])}" + for named_range in domain.args + ) + + "}" + ) + shape = ( + "(" + + ", ".join( + f"{self.visit(named_range.args[2])}-{self.visit(named_range.args[1])}" + for named_range in domain.args + ) + + ")" + ) + body.append( + f"{tmp} = np_as_located_field({','.join(axis_literals)}, origin={origin})(np.full({shape}, np.nan))" + ) + + 
body.append(f"{node.id}({','.join(params)}, **kwargs)") + body = "\n ".join(body) + return f"\ndef {node.id}_wrapper({','.join(non_tmp_params)}, **kwargs):\n {body}\n" + + _BACKEND_NAME = "embedded" def executor(ir: Node, *args, **kwargs): - debug = "debug" in kwargs and kwargs["debug"] == True + debug = "debug" in kwargs and kwargs["debug"] is True + use_tmps = "use_tmps" in kwargs and kwargs["use_tmps"] is True + + tmps = dict() + + def register_tmp(tmp, domain): + tmps[tmp] = domain + + ir = apply_common_transforms( + ir, use_tmps=use_tmps, offset_provider=kwargs["offset_provider"], register_tmp=register_tmp + ) program = EmbeddedDSL.apply(ir) + wrapper = WrapperGenerator.apply(ir, tmps=tmps) offset_literals = ( ir.iter_tree().if_isinstance(OffsetLiteral).getattr("value").if_isinstance(str).to_set() ) @@ -62,25 +110,28 @@ def executor(ir: Node, *args, **kwargs): if debug: print(tmp.name) header = """ +import numpy as np from iterator.builtins import * from iterator.runtime import * +from iterator.embedded import np_as_located_field """ - offset_literals = [f'{l} = offset("{l}")' for l in offset_literals] - axis_literals = [f'{l} = CartesianAxis("{l}")' for l in axis_literals] + offset_literals = [f'{o} = offset("{o}")' for o in offset_literals] + axis_literals = [f'{o} = CartesianAxis("{o}")' for o in axis_literals] tmp.write(header) tmp.write("\n".join(offset_literals)) tmp.write("\n") tmp.write("\n".join(axis_literals)) tmp.write("\n") tmp.write(program) + tmp.write(wrapper) tmp.flush() spec = importlib.util.spec_from_file_location("module.name", tmp.name) foo = importlib.util.module_from_spec(spec) - spec.loader.exec_module(foo) + spec.loader.exec_module(foo) # type: ignore fencil_name = ir.fencil_definitions[0].id - fencil = getattr(foo, fencil_name) + fencil = getattr(foo, fencil_name + "_wrapper") assert "offset_provider" in kwargs new_kwargs = {} @@ -88,7 +139,7 @@ def executor(ir: Node, *args, **kwargs): if "column_axis" in kwargs: new_kwargs["column_axis"] = kwargs["column_axis"] - if not "dispatch_backend" in kwargs: + if "dispatch_backend" not in kwargs: iterator.builtins.builtin_dispatch.push_key("embedded") fencil(*args, **new_kwargs) iterator.builtins.builtin_dispatch.pop_key() diff --git a/src/iterator/backends/lisp.py b/src/iterator/backends/lisp.py index c4eed300dd..acf32b08e6 100644 --- a/src/iterator/backends/lisp.py +++ b/src/iterator/backends/lisp.py @@ -1,10 +1,9 @@ from typing import Any -from eve.codegen import TemplatedGenerator -from eve.codegen import FormatTemplate as as_fmt - -# from yasi import indent_code +from eve.codegen import FormatTemplate as as_fmt +from eve.codegen import TemplatedGenerator from iterator.backends import backend +from iterator.transforms import apply_common_transforms class ToLispLike(TemplatedGenerator): @@ -50,11 +49,19 @@ class ToLispLike(TemplatedGenerator): @classmethod def apply(cls, root, **kwargs: Any) -> str: - generated_code = super().apply(root, **kwargs) - return generated_code - # indented = indent_code(generated_code, "--dialect lisp") - # formatted_code = "".join(indented["indented_code"]) - # return formatted_code + transformed = apply_common_transforms( + root, use_tmps=kwargs.get("use_tmps", False), offset_provider=kwargs["offset_provider"] + ) + generated_code = super().apply(transformed, **kwargs) + try: + from yasi import indent_code + + indented = indent_code(generated_code, "--dialect lisp") + return "".join(indented["indented_code"]) + except ImportError: + return generated_code 
-backend.register_backend("lisp", lambda prog, *args, **kwargs: print(ToLispLike.apply(prog))) +backend.register_backend( + "lisp", lambda prog, *args, **kwargs: print(ToLispLike.apply(prog, **kwargs)) +) diff --git a/src/iterator/builtins.py b/src/iterator/builtins.py index 5dfd0a18fa..c762178c75 100644 --- a/src/iterator/builtins.py +++ b/src/iterator/builtins.py @@ -1,5 +1,6 @@ from iterator.dispatcher import Dispatcher + __all__ = [ "deref", "shift", @@ -20,6 +21,10 @@ "greater", "make_tuple", "nth", + "plus", + "reduce", + "scan", + "shift", ] builtin_dispatch = Dispatcher() diff --git a/src/iterator/dispatcher.py b/src/iterator/dispatcher.py index 15e9a93865..9f14f55b87 100644 --- a/src/iterator/dispatcher.py +++ b/src/iterator/dispatcher.py @@ -1,4 +1,5 @@ -from typing import Any +from typing import Any, Callable, Dict, List + # TODO test @@ -12,9 +13,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: if self.dispatcher.key is None: return self.fun(*args, **kwargs) else: - return self.dispatcher._funs[self.dispatcher.key][self.fun.__name__]( - *args, **kwargs - ) + return self.dispatcher._funs[self.dispatcher.key][self.fun.__name__](*args, **kwargs) def register(self, key): self.dispatcher.register_key(key) @@ -27,15 +26,15 @@ def _impl(fun): class Dispatcher: def __init__(self) -> None: - self._funs = {} - self.key_stack = [] + self._funs: Dict[str, Dict[str, Callable]] = {} + self.key_stack: List[str] = [] @property def key(self): return self.key_stack[-1] if self.key_stack else None def register_key(self, key): - if not key in self._funs: + if key not in self._funs: self._funs[key] = {} def push_key(self, key): diff --git a/src/iterator/embedded.py b/src/iterator/embedded.py index 679d065b94..bfe0388e7e 100644 --- a/src/iterator/embedded.py +++ b/src/iterator/embedded.py @@ -1,32 +1,14 @@ -from dataclasses import dataclass import itertools +import numbers +from dataclasses import dataclass + +import numpy as np import iterator -from iterator.builtins import ( - builtin_dispatch, - is_none, - lift, - reduce, - shift, - deref, - scan, - domain, - named_range, - if_, - or_, - minus, - plus, - mul, - div, - greater, - eq, - nth, - make_tuple, -) +from iterator import builtins from iterator.runtime import CartesianAxis, Offset from iterator.utils import tupelize -import numpy as np -import numbers + EMBEDDED = "embedded" @@ -39,37 +21,37 @@ def __init__(self, tbl, origin_axis, neighbor_axis, max_neighbors) -> None: self.max_neighbors = max_neighbors -@deref.register(EMBEDDED) -def deref(iter): - return iter.deref() +@builtins.deref.register(EMBEDDED) +def deref(it): + return it.deref() -@if_.register(EMBEDDED) +@builtins.if_.register(EMBEDDED) def if_(cond, t, f): return t if cond else f -@or_.register(EMBEDDED) +@builtins.or_.register(EMBEDDED) def or_(a, b): return a or b -@nth.register(EMBEDDED) +@builtins.nth.register(EMBEDDED) def nth(i, tup): return tup[i] -@make_tuple.register(EMBEDDED) +@builtins.make_tuple.register(EMBEDDED) def make_tuple(*args): return (*args,) -@lift.register(EMBEDDED) +@builtins.lift.register(EMBEDDED) def lift(stencil): def impl(*args): class wrap_iterator: - def __init__(self, *, offsets=[], elem=None) -> None: - self.offsets = offsets + def __init__(self, *, offsets=None, elem=None) -> None: + self.offsets = offsets or [] self.elem = elem # TODO needs to be supported by all iterators that represent tuples @@ -105,21 +87,21 @@ def deref(self): return impl -@reduce.register(EMBEDDED) +@builtins.reduce.register(EMBEDDED) def reduce(fun, init): def 
sten(*iters): - # assert check_that_all_iterators_are_compatible(*iters) + # TODO: assert check_that_all_iterators_are_compatible(*iters) first_it = iters[0] n = first_it.max_neighbors() res = init for i in range(n): # we can check a single argument # because all arguments share the same pattern - if iterator.builtins.deref(iterator.builtins.shift(i)(first_it)) is None: + if builtins.deref(builtins.shift(i)(first_it)) is None: break res = fun( res, - *(iterator.builtins.deref(iterator.builtins.shift(i)(it)) for it in iters), + *(builtins.deref(builtins.shift(i)(it)) for it in iters), ) return res @@ -127,7 +109,7 @@ def sten(*iters): class _None: - """Dummy object to allow execution of expression containing Nones in non-active path + """Dummy object to allow execution of expression containing Nones in non-active path. E.g. `if_(is_none(state), 42, 42+state)` @@ -164,12 +146,12 @@ def __getitem__(self, i): return _None() -@is_none.register(EMBEDDED) +@builtins.is_none.register(EMBEDDED) def is_none(arg): return isinstance(arg, _None) -@domain.register(EMBEDDED) +@builtins.domain.register(EMBEDDED) def domain(*args): domain = {} for arg in args: @@ -177,49 +159,51 @@ def domain(*args): return domain -@named_range.register(EMBEDDED) +@builtins.named_range.register(EMBEDDED) def named_range(tag, start, end): return {tag: range(start, end)} -@minus.register(EMBEDDED) +@builtins.minus.register(EMBEDDED) def minus(first, second): return first - second -@plus.register(EMBEDDED) +@builtins.plus.register(EMBEDDED) def plus(first, second): return first + second -@mul.register(EMBEDDED) +@builtins.mul.register(EMBEDDED) def mul(first, second): return first * second -@div.register(EMBEDDED) +@builtins.div.register(EMBEDDED) def div(first, second): return first / second -@eq.register(EMBEDDED) +@builtins.eq.register(EMBEDDED) def eq(first, second): return first == second -@greater.register(EMBEDDED) +@builtins.greater.register(EMBEDDED) def greater(first, second): return first > second -def named_range(axis, range): - return ((axis, i) for i in range) +def named_range_(axis, range_): + return ((axis, i) for i in range_) def domain_iterator(domain): return ( dict(elem) - for elem in itertools.product(*map(lambda tup: named_range(tup[0], tup[1]), domain.items())) + for elem in itertools.product( + *map(lambda tup: named_range_(tup[0], tup[1]), domain.items()) + ) ) @@ -247,7 +231,7 @@ def execute_shift(pos, tag, index, *, offset_provider): ] return new_pos - assert False + raise AssertionError() # The following holds for shifts: @@ -287,13 +271,62 @@ def get_open_offsets(*offsets): return group_offsets(*offsets)[1] +class Undefined: + def __float__(self): + return np.nan + + @classmethod + def _setup_math_operations(cls): + ops = [ + "__add__", + "__sub__", + "__mul__", + "__matmul__", + "__truediv__", + "__floordiv__", + "__mod__", + "__divmod__", + "__pow__", + "__lshift__", + "__rshift__", + "__and__", + "__xor__", + "__or__", + "__radd__", + "__rsub__", + "__rmul__", + "__rmatmul__", + "__rtruediv__", + "__rfloordiv__", + "__rmod__", + "__rdivmod__", + "__rpow__", + "__rlshift__", + "__rrshift__", + "__rand__", + "__rxor__", + "__ror__", + "__neg__", + "__pos__", + "__abs__", + "__invert__", + ] + for op in ops: + setattr(cls, op, lambda self, *args, **kwargs: _UNDEFINED) + + +Undefined._setup_math_operations() + +_UNDEFINED = Undefined() + + class MDIterator: def __init__( - self, field, pos, *, incomplete_offsets=[], offset_provider, column_axis=None + self, field, pos, *, incomplete_offsets=None, 
offset_provider, column_axis=None ) -> None: self.field = field self.pos = pos - self.incomplete_offsets = incomplete_offsets + self.incomplete_offsets = incomplete_offsets or [] self.offset_provider = offset_provider self.column_axis = column_axis @@ -331,7 +364,10 @@ def deref(self): shifted_pos, slice_axises=slice_column, ) - return self.field[ordered_indices] + try: + return self.field[ordered_indices] + except IndexError: + return _UNDEFINED def make_in_iterator(inp, pos, offset_provider, *, column_axis): @@ -352,7 +388,7 @@ def make_in_iterator(inp, pos, offset_provider, *, column_axis): ) -builtin_dispatch.push_key(EMBEDDED) # makes embedded the default +builtins.builtin_dispatch.push_key(EMBEDDED) # makes embedded the default class LocatedField: @@ -388,8 +424,9 @@ def shape(self): return self.array().shape -def get_ordered_indices(axises, pos, *, slice_axises={}): - """pos is a dictionary from axis to offset""" +def get_ordered_indices(axises, pos, *, slice_axises=None): + """pos is a dictionary from axis to offset.""" # noqa: D403 + slice_axises = slice_axises or dict() assert all(axis in [*pos.keys(), *slice_axises] for axis in axises) return tuple(pos[axis] if axis in pos else slice_axises[axis] for axis in axises) @@ -444,10 +481,10 @@ def index_field(axis): return LocatedField(lambda index: index[0], (axis,)) -@iterator.builtins.shift.register(EMBEDDED) +@builtins.shift.register(EMBEDDED) def shift(*offsets): - def impl(iter): - return iter.shift(*offsets) + def impl(it): + return it.shift(*offsets) return impl @@ -455,13 +492,13 @@ def impl(iter): @dataclass class Column: axis: CartesianAxis - range: range + range: range # noqa: A003 class ScanArgIterator: - def __init__(self, wrapped_iter, k_pos, *, offsets=[]) -> None: + def __init__(self, wrapped_iter, k_pos, *, offsets=None) -> None: self.wrapped_iter = wrapped_iter - self.offsets = offsets + self.offsets = offsets or [] self.k_pos = k_pos def deref(self): @@ -472,13 +509,13 @@ def shift(self, *offsets): def shifted_scan_arg(k_pos): - def impl(iter): - return ScanArgIterator(iter, k_pos=k_pos) + def impl(it): + return ScanArgIterator(it, k_pos=k_pos) return impl -def fendef_embedded(fun, *args, **kwargs): +def fendef_embedded(fun, *args, **kwargs): # noqa: 536 assert "offset_provider" in kwargs @iterator.runtime.closure.register(EMBEDDED) @@ -490,7 +527,7 @@ def closure(domain, sten, outs, ins): # domain is Dict[axis, range] column = Column(_column_axis, domain[_column_axis]) del domain[_column_axis] - @iterator.builtins.scan.register( + @builtins.scan.register( EMBEDDED ) # TODO this is a bit ugly, alternative: pass scan range via iterator def scan(scan_pass, is_forward, init): diff --git a/src/iterator/ir.py b/src/iterator/ir.py index f2214a4dd1..58ca48d52f 100644 --- a/src/iterator/ir.py +++ b/src/iterator/ir.py @@ -1,4 +1,5 @@ from typing import List, Union + from eve import Node from eve.traits import SymbolName, SymbolTableTrait from eve.type_definitions import SymbolRef @@ -6,7 +7,7 @@ class Sym(Node): # helper - id: SymbolName + id: SymbolName # noqa: A003 class Expr(Node): @@ -42,7 +43,7 @@ class AxisLiteral(Expr): class SymRef(Expr): - id: SymbolRef + id: SymbolRef # noqa: A003 class Lambda(Expr, SymbolTableTrait): @@ -56,7 +57,7 @@ class FunCall(Expr): class FunctionDefinition(Node, SymbolTableTrait): - id: SymbolName + id: SymbolName # noqa: A003 params: List[Sym] expr: Expr @@ -68,7 +69,7 @@ def __hash__(self): class Setq(Node): - id: SymbolName + id: SymbolName # noqa: A003 expr: Expr @@ -80,7 +81,7 @@ 
class StencilClosure(Node): class FencilDefinition(Node, SymbolTableTrait): - id: SymbolName + id: SymbolName # noqa: A003 params: List[Sym] closures: List[StencilClosure] diff --git a/src/iterator/library.py b/src/iterator/library.py index b0e1dbfb09..18c87aeb49 100644 --- a/src/iterator/library.py +++ b/src/iterator/library.py @@ -1,7 +1,7 @@ from iterator.builtins import reduce -def sum(fun=None): +def sum_(fun=None): if fun is None: return reduce(lambda a, b: a + b, 0) else: @@ -10,6 +10,3 @@ def sum(fun=None): def dot(a, b): return reduce(lambda acc, a_n, c_n: acc + a_n * c_n, 0)(a, b) - - -# dot = reduce(lambda a, b, c: a + b * c, 0) diff --git a/src/iterator/runtime.py b/src/iterator/runtime.py index 5b74a456c4..deca4a6e33 100644 --- a/src/iterator/runtime.py +++ b/src/iterator/runtime.py @@ -1,14 +1,15 @@ -from typing import Union from dataclasses import dataclass +from typing import Callable, Dict, Optional, Union from iterator.builtins import BackendNotSelectedError, builtin_dispatch + __all__ = ["offset", "fundef", "fendef", "closure", "CartesianAxis"] @dataclass class Offset: - value: Union[int, str] = None + value: Optional[Union[int, str]] = None def __hash__(self) -> int: return hash(self.value) @@ -26,7 +27,7 @@ def __hash__(self) -> int: return hash(self.value) -fendef_registry = {} +fendef_registry: Dict[Optional[Callable], Callable] = {} # TODO the dispatching is linear, not sure if there is an easy way to make it constant diff --git a/src/iterator/tracing.py b/src/iterator/tracing.py index cca543ad87..49731d23ef 100644 --- a/src/iterator/tracing.py +++ b/src/iterator/tracing.py @@ -1,9 +1,12 @@ -from iterator.runtime import CartesianAxis -from eve import Node import inspect +from typing import List + +import iterator +from eve import Node from iterator.backend_executor import execute_program from iterator.ir import ( AxisLiteral, + BoolLiteral, Expr, FencilDefinition, FloatLiteral, @@ -12,15 +15,14 @@ IntLiteral, Lambda, NoneLiteral, - BoolLiteral, OffsetLiteral, Program, StencilClosure, Sym, SymRef, ) -from iterator.backends import backend -import iterator +from iterator.runtime import CartesianAxis + TRACING = "tracing" @@ -58,10 +60,6 @@ def __truediv__(self, other): def __sub__(self, other): return FunCall(fun=SymRef(id="minus"), args=[self, make_node(other)]) - @monkeypatch_method(Expr) - def __eq__(self, other): - return FunCall(fun=SymRef(id="eq"), args=[self, make_node(other)]) - @monkeypatch_method(Expr) def __gt__(self, other): return FunCall(fun=SymRef(id="greater"), args=[self, make_node(other)]) @@ -83,8 +81,8 @@ def __call__(self, *args): return FunCall(fun=SymRef(id=str(self.id)), args=[*make_node(args)]) -def _s(id): - return SymRef(id=id) +def _s(id_): + return SymRef(id=id_) def trace_function_argument(arg): @@ -225,6 +223,8 @@ def make_node(o): return list(make_node(arg) for arg in o) if o is None: return NoneLiteral() + if isinstance(o, iterator.runtime.FundefDispatcher): + return SymRef(id=o.fun.__name__) raise NotImplementedError(f"Cannot handle {o}") @@ -266,12 +266,12 @@ def __bool__(self): class Tracer: - fundefs = [] - closures = [] + fundefs: List[FunctionDefinition] = [] + closures: List[StencilClosure] = [] @classmethod def add_fundef(cls, fun): - if not fun in cls.fundefs: + if fun not in cls.fundefs: cls.fundefs.append(fun) @classmethod @@ -293,7 +293,7 @@ def closure(domain, stencil, outputs, inputs): Tracer.add_closure( StencilClosure( domain=domain, - stencil=SymRef(id=str(stencil.__name__)), + stencil=make_node(stencil), 
outputs=outputs, inputs=inputs, ) diff --git a/src/iterator/transforms/__init__.py b/src/iterator/transforms/__init__.py new file mode 100644 index 0000000000..1fc5154fa0 --- /dev/null +++ b/src/iterator/transforms/__init__.py @@ -0,0 +1,4 @@ +from iterator.transforms.common import apply_common_transforms + + +__all__ = ["apply_common_transforms"] diff --git a/src/iterator/transforms/collect_shifts.py b/src/iterator/transforms/collect_shifts.py new file mode 100644 index 0000000000..9e326295de --- /dev/null +++ b/src/iterator/transforms/collect_shifts.py @@ -0,0 +1,31 @@ +from typing import Dict, List + +from eve import NodeVisitor +from iterator import ir + + +class CollectShifts(NodeVisitor): + def visit_FunCall(self, node: ir.FunCall, *, shifts: Dict[str, List[tuple]]): + if isinstance(node.fun, ir.SymRef) and node.fun.id == "deref": + assert len(node.args) == 1 + arg = node.args[0] + if isinstance(arg, ir.SymRef): + # direct deref of a symbol: deref(sym) + shifts.setdefault(arg.id, []).append(()) + elif ( + isinstance(arg, ir.FunCall) + and isinstance(arg.fun, ir.FunCall) + and isinstance(arg.fun.fun, ir.SymRef) + and arg.fun.fun.id == "shift" + and isinstance(arg.args[0], ir.SymRef) + ): + # deref of a shifted symbol: deref(shift(...)(sym)) + assert len(arg.args) == 1 + sym = arg.args[0] + shift_args = arg.fun.args + shifts.setdefault(sym.id, []).append(tuple(shift_args)) + else: + raise RuntimeError(f"Unexpected node: {node}") + elif isinstance(node.fun, ir.SymRef) and node.fun.id in ("lift", "scan"): + raise RuntimeError(f"Unsupported node: {node}") + return self.generic_visit(node, shifts=shifts) diff --git a/src/iterator/transforms/common.py b/src/iterator/transforms/common.py new file mode 100644 index 0000000000..349b5ad237 --- /dev/null +++ b/src/iterator/transforms/common.py @@ -0,0 +1,21 @@ +from iterator.transforms.global_tmps import CreateGlobalTmps +from iterator.transforms.inline_fundefs import InlineFundefs, PruneUnreferencedFundefs +from iterator.transforms.inline_lambdas import InlineLambdas +from iterator.transforms.inline_lifts import InlineLifts +from iterator.transforms.normalize_shifts import NormalizeShifts + + +def apply_common_transforms(ir, use_tmps=False, offset_provider=None, register_tmp=None): + ir = InlineFundefs().visit(ir) + ir = PruneUnreferencedFundefs().visit(ir) + ir = NormalizeShifts().visit(ir) + if not use_tmps: + ir = InlineLifts().visit(ir) + ir = InlineLambdas().visit(ir) + ir = NormalizeShifts().visit(ir) + if use_tmps: + assert offset_provider is not None + ir = CreateGlobalTmps().visit( + ir, offset_provider=offset_provider, register_tmp=register_tmp + ) + return ir diff --git a/src/iterator/transforms/global_tmps.py b/src/iterator/transforms/global_tmps.py new file mode 100644 index 0000000000..2d83524159 --- /dev/null +++ b/src/iterator/transforms/global_tmps.py @@ -0,0 +1,121 @@ +from typing import Dict, List + +from eve import NodeTranslator +from iterator import ir +from iterator.runtime import CartesianAxis +from iterator.transforms.collect_shifts import CollectShifts +from iterator.transforms.popup_tmps import PopupTmps + + +class CreateGlobalTmps(NodeTranslator): + @staticmethod + def _extend_domain(domain: ir.FunCall, offset_provider, shifts): + assert isinstance(domain.fun, ir.SymRef) and domain.fun.id == "domain" + assert all(isinstance(o, CartesianAxis) for o in offset_provider.values()) + + offset_limits = {k: (0, 0) for k in offset_provider.keys()} + + for shift in shifts: + offsets = {k: 0 for k in offset_provider.keys()} + 
for k, v in zip(shift[0::2], shift[1::2]): + offsets[k.value] += v.value + for k, v in offsets.items(): + old_min, old_max = offset_limits[k] + offset_limits[k] = (min(old_min, v), max(old_max, v)) + + offset_limits = {v.value: offset_limits[k] for k, v in offset_provider.items()} + + named_ranges = [] + for named_range in domain.args: + assert ( + isinstance(named_range, ir.FunCall) + and isinstance(named_range.fun, ir.SymRef) + and named_range.fun.id == "named_range" + ) + axis_literal, lower_bound, upper_bound = named_range.args + assert isinstance(axis_literal, ir.AxisLiteral) + + lower_offset, upper_offset = offset_limits.get(axis_literal.value, (0, 0)) + named_ranges.append( + ir.FunCall( + fun=named_range.fun, + args=[ + axis_literal, + ir.FunCall( + fun=ir.SymRef(id="plus"), + args=[lower_bound, ir.IntLiteral(value=lower_offset)], + ) + if lower_offset + else lower_bound, + ir.FunCall( + fun=ir.SymRef(id="plus"), + args=[upper_bound, ir.IntLiteral(value=upper_offset)], + ) + if upper_offset + else upper_bound, + ], + ) + ) + + return ir.FunCall(fun=domain.fun, args=named_ranges) + + def visit_FencilDefinition(self, node: ir.FencilDefinition, *, offset_provider, register_tmp): + tmps: List[ir.Sym] = [] + + def handle_arg(arg): + if isinstance(arg, ir.SymRef): + return arg + if ( + isinstance(arg, ir.FunCall) + and isinstance(arg.fun, ir.FunCall) + and arg.fun.fun.id == "lift" + ): + ref = ir.SymRef(id=f"tmp{len(tmps)}") + tmps.append(ir.Sym(id=ref.id)) + assert len(arg.fun.args) == 1 + unlifted = ir.FunCall(fun=arg.fun.args[0], args=arg.args) + todos.append(([ref], unlifted)) + return ref + raise AssertionError() + + closures = [] + tmp_domains = dict() + for closure in reversed(node.closures): + assert isinstance(closure.stencil, ir.Lambda) + wrapped_stencil = ir.FunCall(fun=closure.stencil, args=closure.inputs) + popped_stencil = PopupTmps().visit(wrapped_stencil) + todos = [(closure.outputs, popped_stencil)] + + shifts: Dict[str, List[tuple]] = dict() + domain = closure.domain + while todos: + outputs, call = todos.pop() + output_shifts: List[tuple] = sum((shifts.get(o.id, []) for o in outputs), []) + domain = self._extend_domain(domain, offset_provider, output_shifts) + for output in outputs: + if output.id in {tmp.id for tmp in tmps}: + assert output.id not in tmp_domains + tmp_domains[output.id] = domain + closure = ir.StencilClosure( + domain=domain, + stencil=call.fun, + outputs=outputs, + inputs=[handle_arg(arg) for arg in call.args], + ) + local_shifts: Dict[str, List[tuple]] = dict() + CollectShifts().visit(closure.stencil, shifts=local_shifts) + input_map = { + param.id: arg.id for param, arg in zip(closure.stencil.params, closure.inputs) + } + for k, v in local_shifts.items(): + shifts.setdefault(input_map[k], []).extend(v) + closures.append(closure) + + assert {tmp.id for tmp in tmps} == set(tmp_domains.keys()) + if register_tmp is not None: + for tmp, domain in tmp_domains.items(): + register_tmp(tmp, domain) + + return ir.FencilDefinition( + id=node.id, params=node.params + tmps, closures=list(reversed(closures)) + ) diff --git a/src/iterator/transforms/inline_fundefs.py b/src/iterator/transforms/inline_fundefs.py new file mode 100644 index 0000000000..308fef6953 --- /dev/null +++ b/src/iterator/transforms/inline_fundefs.py @@ -0,0 +1,35 @@ +from typing import Any, Dict, Set + +from eve import NOTHING, NodeTranslator +from iterator import ir + + +class InlineFundefs(NodeTranslator): + def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]): + if node.id 
in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition): + return ir.Lambda( + params=self.generic_visit(symbol.params, symtable=symtable), + expr=self.generic_visit(symbol.expr, symtable=symtable), + ) + return self.generic_visit(node) + + def visit_Program(self, node: ir.Program): + return self.generic_visit(node, symtable=node.symtable_) + + +class PruneUnreferencedFundefs(NodeTranslator): + def visit_FunctionDefinition( + self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool + ): + if second_pass and node.id not in referenced: + return NOTHING + return self.generic_visit(node, referenced=referenced, second_pass=second_pass) + + def visit_SymRef(self, node: ir.SymRef, *, referenced: Set[str], second_pass: bool): + referenced.add(node.id) + return node + + def visit_Program(self, node: ir.Program): + referenced: Set[str] = set() + self.generic_visit(node, referenced=referenced, second_pass=False) + return self.generic_visit(node, referenced=referenced, second_pass=True) diff --git a/src/iterator/transforms/inline_lambdas.py b/src/iterator/transforms/inline_lambdas.py new file mode 100644 index 0000000000..c38627acad --- /dev/null +++ b/src/iterator/transforms/inline_lambdas.py @@ -0,0 +1,32 @@ +from eve import NodeTranslator +from iterator import ir +from iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols + + +class InlineLambdas(NodeTranslator): + def visit_FunCall(self, node: ir.FunCall): + node = self.generic_visit(node) + if isinstance(node.fun, ir.Lambda): + assert len(node.fun.params) == len(node.args) + refs = set.union( + *( + arg.iter_tree().if_isinstance(ir.SymRef).getattr("id").to_set() + for arg in node.args + ) + ) + syms = node.fun.expr.iter_tree().if_isinstance(ir.Sym).getattr("id").to_set() + clashes = refs & syms + expr = node.fun.expr + if clashes: + + def new_name(name): + while name in refs or name in syms: + name += "_" + return name + + name_map = {sym: new_name(sym) for sym in clashes} + expr = RenameSymbols().visit(expr, name_map=name_map) + + symbol_map = {param.id: arg for param, arg in zip(node.fun.params, node.args)} + return RemapSymbolRefs().visit(expr, symbol_map=symbol_map) + return node diff --git a/src/iterator/transforms/inline_lifts.py b/src/iterator/transforms/inline_lifts.py new file mode 100644 index 0000000000..02af36f799 --- /dev/null +++ b/src/iterator/transforms/inline_lifts.py @@ -0,0 +1,37 @@ +from eve import NodeTranslator +from iterator import ir + + +class InlineLifts(NodeTranslator): + def visit_FunCall(self, node: ir.FunCall): + node = self.generic_visit(node) + if isinstance(node.fun, ir.SymRef) and node.fun.id == "deref": + assert len(node.args) == 1 + if ( + isinstance(node.args[0], ir.FunCall) + and isinstance(node.args[0].fun, ir.FunCall) + and isinstance(node.args[0].fun.fun, ir.SymRef) + and node.args[0].fun.fun.id == "lift" + ): + # deref(lift(f)(args...)) -> f(args...) 
+ assert len(node.args[0].fun.args) == 1 + f = node.args[0].fun.args[0] + args = node.args[0].args + return ir.FunCall(fun=f, args=args) + elif ( + isinstance(node.args[0], ir.FunCall) + and isinstance(node.args[0].fun, ir.FunCall) + and isinstance(node.args[0].fun.fun, ir.SymRef) + and node.args[0].fun.fun.id == "shift" + and isinstance(node.args[0].args[0], ir.FunCall) + and isinstance(node.args[0].args[0].fun, ir.FunCall) + and isinstance(node.args[0].args[0].fun.fun, ir.SymRef) + and node.args[0].args[0].fun.fun.id == "lift" + ): + # deref(shift(...)(lift(f)(args...)) -> f(shift(...)(args)...) + f = node.args[0].args[0].fun.args[0] + shift = node.args[0].fun + args = node.args[0].args[0].args + res = ir.FunCall(fun=f, args=[ir.FunCall(fun=shift, args=[arg]) for arg in args]) + return res + return node diff --git a/src/iterator/transforms/normalize_shifts.py b/src/iterator/transforms/normalize_shifts.py new file mode 100644 index 0000000000..500ff232cf --- /dev/null +++ b/src/iterator/transforms/normalize_shifts.py @@ -0,0 +1,25 @@ +from eve import NodeTranslator +from iterator import ir + + +class NormalizeShifts(NodeTranslator): + def visit_FunCall(self, node: ir.FunCall): + node = self.generic_visit(node) + if ( + isinstance(node.fun, ir.FunCall) + and isinstance(node.fun.fun, ir.SymRef) + and node.fun.fun.id == "shift" + and node.args + and isinstance(node.args[0], ir.FunCall) + and isinstance(node.args[0].fun.fun, ir.SymRef) + and node.args[0].fun.fun.id == "shift" + ): + # shift(args1...)(shift(args2...)(it)) -> shift(args2..., args1...)(it) + assert len(node.args) == 1 + return ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="shift"), args=node.args[0].fun.args + node.fun.args + ), + args=node.args[0].args, + ) + return node diff --git a/src/iterator/transforms/popup_tmps.py b/src/iterator/transforms/popup_tmps.py new file mode 100644 index 0000000000..b67edb5cb3 --- /dev/null +++ b/src/iterator/transforms/popup_tmps.py @@ -0,0 +1,60 @@ +from typing import Dict, Optional + +from eve import NodeTranslator +from iterator import ir +from iterator.transforms.remap_symbols import RemapSymbolRefs + + +class PopupTmps(NodeTranslator): + _counter = 0 + + def visit_FunCall(self, node: ir.FunCall, *, lifts: Optional[Dict[str, ir.Node]] = None): + if ( + isinstance(node.fun, ir.FunCall) + and isinstance(node.fun.fun, ir.SymRef) + and node.fun.fun.id == "lift" + ): + # lifted lambda call + assert len(node.fun.args) == 1 and isinstance(node.fun.args[0], ir.Lambda) + assert lifts is not None + + nested_lifts: Dict[str, ir.Node] = dict() + res = self.generic_visit(node, lifts=nested_lifts) + # TODO: avoid possible symbol name clashes + symbol = f"t{self._counter}" + self._counter += 1 + + symbol_map = {param.id: arg for param, arg in zip(res.fun.args[0].params, res.args)} + new_args = [ + RemapSymbolRefs().visit(arg, symbol_map=symbol_map) for arg in nested_lifts.values() + ] + assert len(res.fun.args[0].params) == len(res.args + new_args) + call = ir.FunCall(fun=res.fun, args=res.args + new_args) + + # return existing definition if the same expression was lifted before + for k, v in lifts.items(): + if call == v: + return ir.SymRef(id=k) + + lifts[symbol] = call + return ir.SymRef(id=symbol) + elif isinstance(node.fun, ir.Lambda): + # direct lambda call + lifts = dict() + res = self.generic_visit(node, lifts=lifts) + symbol_map = {param.id: arg for param, arg in zip(res.fun.params, res.args)} + new_args = [ + RemapSymbolRefs().visit(arg, symbol_map=symbol_map) for arg in lifts.values() + ] + 
assert len(res.fun.params) == len(res.args + new_args) + return ir.FunCall(fun=res.fun, args=res.args + new_args) + + return self.generic_visit(node, lifts=lifts) + + def visit_Lambda(self, node: ir.Lambda, *, lifts): + node = self.generic_visit(node, lifts=lifts) + if not lifts: + return node + + new_params = [ir.Sym(id=param) for param in lifts.keys()] + return ir.Lambda(params=node.params + new_params, expr=node.expr) diff --git a/src/iterator/transforms/remap_symbols.py b/src/iterator/transforms/remap_symbols.py new file mode 100644 index 0000000000..8a17944b59 --- /dev/null +++ b/src/iterator/transforms/remap_symbols.py @@ -0,0 +1,48 @@ +from typing import Any, Dict, Optional, Set + +from eve import NodeTranslator +from iterator import ir + + +class RemapSymbolRefs(NodeTranslator): + def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): + return symbol_map.get(node.id, node) + + def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]): + params = {str(p.id) for p in node.params} + new_symbol_map = {k: v for k, v in symbol_map.items() if k not in params} + return ir.Lambda( + params=node.params, + expr=self.generic_visit(node.expr, symbol_map=new_symbol_map), + ) + + def generic_visit(self, node: ir.Node, **kwargs: Any): + assert isinstance(node, ir.SymbolTableTrait) == isinstance( + node, ir.Lambda + ), "found unexpected new symbol scope" + return super().generic_visit(node, **kwargs) + + +class RenameSymbols(NodeTranslator): + def visit_Sym( + self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None + ): + if active and node.id in active: + return ir.Sym(id=name_map.get(node.id, node.id)) + return node + + def visit_SymRef( + self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None + ): + if active and node.id in active: + return ir.SymRef(id=name_map.get(node.id, node.id)) + return node + + def generic_visit( + self, node: ir.Node, *, name_map: Dict[str, str], active: Optional[Set[str]] = None + ): + if isinstance(node, ir.SymbolTableTrait): + if active is None: + active = set() + active = active | set(node.symtable_) + return super().generic_visit(node, name_map=name_map, active=active) diff --git a/src/iterator/util/sym_validation.py b/src/iterator/util/sym_validation.py index a8432137aa..17af043739 100644 --- a/src/iterator/util/sym_validation.py +++ b/src/iterator/util/sym_validation.py @@ -1,10 +1,12 @@ from typing import Any, Dict, List, Type + +import pydantic + +from eve import Node from eve.traits import SymbolTableTrait from eve.type_definitions import SymbolRef -from eve.visitors import NodeVisitor -import pydantic from eve.typingx import RootValidatorType, RootValidatorValuesType -from eve import Node +from eve.visitors import NodeVisitor def validate_symbol_refs() -> RootValidatorType: @@ -17,9 +19,7 @@ class SymtableValidator(NodeVisitor): def __init__(self) -> None: self.missing_symbols: List[str] = [] - def visit_Node( - self, node: Node, *, symtable: Dict[str, Any], **kwargs: Any - ) -> None: + def visit_Node(self, node: Node, *, symtable: Dict[str, Any], **kwargs: Any) -> None: for name, metadata in node.__node_children__.items(): if isinstance(metadata["definition"].type_, type) and issubclass( metadata["definition"].type_, SymbolRef @@ -39,9 +39,7 @@ def apply(cls, node: Node, *, symtable: Dict[str, Any]) -> List[str]: missing_symbols = [] for v in values.values(): - missing_symbols.extend( - SymtableValidator.apply(v, symtable=values["symtable_"]) - ) + 
missing_symbols.extend(SymtableValidator.apply(v, symtable=values["symtable_"])) if len(missing_symbols) > 0: raise ValueError("Symbols {} not found.".format(missing_symbols)) diff --git a/tests/iterator_tests/conftest.py b/tests/iterator_tests/conftest.py new file mode 100644 index 0000000000..07218b13aa --- /dev/null +++ b/tests/iterator_tests/conftest.py @@ -0,0 +1,20 @@ +import pytest + + +@pytest.fixture(params=[False, True], ids=lambda p: f"use_tmps={p}") +def use_tmps(request): + return request.param + + +@pytest.fixture( + params=[ + # (backend, do_validate) + ("lisp", False), + ("cpptoy", False), + ("embedded", True), + ("double_roundtrip", True), + ], + ids=lambda p: f"backend={p[0]}", +) +def backend(request): + return request.param diff --git a/tests/iterator_tests/fvm_nabla_setup.py b/tests/iterator_tests/fvm_nabla_setup.py index b8f01d6cee..e4d7379583 100644 --- a/tests/iterator_tests/fvm_nabla_setup.py +++ b/tests/iterator_tests/fvm_nabla_setup.py @@ -11,18 +11,19 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import math + +import numpy as np from atlas4py import ( - StructuredGrid, - Topology, Config, + StructuredGrid, StructuredMeshGenerator, - functionspace, + Topology, build_edges, - build_node_to_edge_connectivity, build_median_dual_mesh, + build_node_to_edge_connectivity, + functionspace, ) -import numpy as np -import math def assert_close(expected, actual): @@ -30,13 +31,16 @@ def assert_close(expected, actual): class nabla_setup: + @staticmethod def _default_config(): config = Config() config["triangulate"] = True config["angle"] = 20.0 return config - def __init__(self, *, grid=StructuredGrid("O32"), config=_default_config()): + def __init__(self, *, grid=StructuredGrid("O32"), config=None): + if config is None: + config = self._default_config() mesh = StructuredMeshGenerator(config).generate(grid) fs_edges = functionspace.EdgeColumns(mesh, halo=1) diff --git a/tests/iterator_tests/hdiff_reference.py b/tests/iterator_tests/hdiff_reference.py index e0f3bb31df..9788746fde 100644 --- a/tests/iterator_tests/hdiff_reference.py +++ b/tests/iterator_tests/hdiff_reference.py @@ -12,8 +12,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later -import pytest import numpy as np +import pytest def hdiff_reference_impl(): diff --git a/tests/iterator_tests/test_anton_toy.py b/tests/iterator_tests/test_anton_toy.py index 5b9da41c4d..1ba8ba4865 100644 --- a/tests/iterator_tests/test_anton_toy.py +++ b/tests/iterator_tests/test_anton_toy.py @@ -1,8 +1,9 @@ -from iterator.builtins import * -from iterator.embedded import np_as_located_field -from iterator.runtime import * import numpy as np +from iterator.builtins import deref, domain, lift, named_range, shift +from iterator.embedded import np_as_located_field +from iterator.runtime import CartesianAxis, closure, fendef, fundef, offset + @fundef def ldif(d): @@ -11,13 +12,13 @@ def ldif(d): @fundef def rdif(d): - # return compose(ldif(d), shift(d, 1)) + # return compose(ldif(d), shift(d, 1)) # noqa: E800 return lambda inp: ldif(d)(shift(d, 1)(inp)) @fundef def dif2(d): - # return compose(ldif(d), lift(rdif(d))) + # return compose(ldif(d), lift(rdif(d))) # noqa: E800 return lambda inp: ldif(d)(lift(rdif(d))(inp)) @@ -35,20 +36,16 @@ def lap(inp): KDim = CartesianAxis("KDim") -@fendef -def fencil(x, y, z, output, input): +@fendef(offset_provider={"i": IDim, "j": JDim}) +def fencil(x, y, z, out, inp): closure( domain(named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z)), lap, - [output], - [input], + [out], + [inp], ) 
-fencil(*([None] * 5), backend="lisp") -fencil(*([None] * 5), backend="cpptoy") - - def naive_lap(inp): shape = [inp.shape[0] - 2, inp.shape[1] - 2, inp.shape[2]] out = np.zeros(shape) @@ -61,7 +58,8 @@ def naive_lap(inp): return out -def test_anton_toy(): +def test_anton_toy(backend, use_tmps): + backend, validate = backend shape = [5, 7, 9] rng = np.random.default_rng() inp = np_as_located_field(IDim, JDim, KDim, origin={IDim: 1, JDim: 1, KDim: 0})( @@ -76,8 +74,9 @@ def test_anton_toy(): shape[2], out, inp, - backend="double_roundtrip", - offset_provider={"i": IDim, "j": JDim}, + backend=backend, + use_tmps=use_tmps, ) - assert np.allclose(out, ref) + if validate: + assert np.allclose(out, ref) diff --git a/tests/iterator_tests/test_cartesian_offset_provider.py b/tests/iterator_tests/test_cartesian_offset_provider.py index de198b23bd..88ee0405a6 100644 --- a/tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/iterator_tests/test_cartesian_offset_provider.py @@ -1,8 +1,9 @@ +import numpy as np + from iterator.builtins import * +from iterator.embedded import np_as_located_field from iterator.runtime import * -from iterator.embedded import np_as_located_field -import numpy as np I = offset("I") J = offset("J") diff --git a/tests/iterator_tests/test_column_stencil.py b/tests/iterator_tests/test_column_stencil.py index b8458ac3cb..6e87424103 100644 --- a/tests/iterator_tests/test_column_stencil.py +++ b/tests/iterator_tests/test_column_stencil.py @@ -1,7 +1,10 @@ +import numpy as np +import pytest + +from iterator.builtins import * from iterator.embedded import np_as_located_field from iterator.runtime import * -from iterator.builtins import * -import numpy as np + I = offset("I") K = offset("K") @@ -26,7 +29,8 @@ def fencil(i_size, k_size, inp, out): ) -def test_column_stencil(): +def test_column_stencil(backend, use_tmps): + backend, validate = backend shape = [5, 7] inp = np_as_located_field(IDim, KDim)( np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 1]) @@ -35,12 +39,22 @@ def test_column_stencil(): ref = np.asarray(inp)[1:, 1:] - fencil(shape[0], shape[1], inp, out, offset_provider={"I": IDim, "K": KDim}) + fencil( + shape[0], + shape[1], + inp, + out, + offset_provider={"I": IDim, "K": KDim}, + backend=backend, + use_tmps=use_tmps, + ) - assert np.allclose(ref, out) + if validate: + assert np.allclose(ref, out) -def test_column_stencil_with_k_origin(): +def test_column_stencil_with_k_origin(backend, use_tmps): + backend, validate = backend shape = [5, 7] raw_inp = np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 2]) inp = np_as_located_field(IDim, KDim, origin={IDim: 0, KDim: 1})(raw_inp) @@ -48,9 +62,18 @@ def test_column_stencil_with_k_origin(): ref = np.asarray(inp)[1:, 2:] - fencil(shape[0], shape[1], inp, out, offset_provider={"I": IDim, "K": KDim}) + fencil( + shape[0], + shape[1], + inp, + out, + offset_provider={"I": IDim, "K": KDim}, + backend=backend, + use_tmps=use_tmps, + ) - assert np.allclose(ref, out) + if validate: + assert np.allclose(ref, out) @fundef @@ -73,7 +96,10 @@ def ksum_fencil(i_size, k_size, inp, out): ) -def test_ksum_scan(): +def test_ksum_scan(backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently not supported for scans") + backend, validate = backend shape = [1, 7] inp = np_as_located_field(IDim, KDim)(np.asarray([list(range(7))])) out = np_as_located_field(IDim, KDim)(np.zeros(shape)) @@ -86,10 +112,12 @@ def test_ksum_scan(): inp, out, offset_provider={"I": IDim, "K": KDim}, - 
backend="double_roundtrip", + backend=backend, + use_tmps=use_tmps, ) - assert np.allclose(ref, np.asarray(out)) + if validate: + assert np.allclose(ref, np.asarray(out)) @fundef @@ -107,7 +135,10 @@ def ksum_back_fencil(i_size, k_size, inp, out): ) -def test_ksum_back_scan(): +def test_ksum_back_scan(backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently not supported for scans") + backend, validate = backend shape = [1, 7] inp = np_as_located_field(IDim, KDim)(np.asarray([list(range(7))])) out = np_as_located_field(IDim, KDim)(np.zeros(shape)) @@ -120,7 +151,9 @@ def test_ksum_back_scan(): inp, out, offset_provider={"I": IDim, "K": KDim}, - backend="double_roundtrip", + backend=backend, + use_tmps=use_tmps, ) - assert np.allclose(ref, np.asarray(out)) + if validate: + assert np.allclose(ref, np.asarray(out)) diff --git a/tests/iterator_tests/test_fvm_nabla.py b/tests/iterator_tests/test_fvm_nabla.py index 0c18b6b4d8..4ab209a4e7 100644 --- a/tests/iterator_tests/test_fvm_nabla.py +++ b/tests/iterator_tests/test_fvm_nabla.py @@ -11,18 +11,16 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import itertools +import numpy as np +import pytest +from iterator import library from iterator.atlas_utils import AtlasTable +from iterator.builtins import * from iterator.embedded import NeighborTableOffsetProvider, index_field, np_as_located_field from iterator.runtime import * -from iterator.builtins import * -from iterator import library -from .fvm_nabla_setup import ( - assert_close, - nabla_setup, -) -import numpy as np + +from .fvm_nabla_setup import assert_close, nabla_setup Vertex = CartesianAxis("Vertex") @@ -90,7 +88,10 @@ def nabla( ) -def test_compute_zavgS(): +def test_compute_zavgS(backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently only supported for cartesian") + backend, validate = backend setup = nabla_setup() pp = np_as_located_field(Vertex)(setup.input_field) @@ -108,8 +109,9 @@ def test_compute_zavgS(): offset_provider={"E2V": e2v}, ) - assert_close(-199755464.25741270, min(zavgS)) - assert_close(388241977.58389181, max(zavgS)) + if validate: + assert_close(-199755464.25741270, min(zavgS)) + assert_close(388241977.58389181, max(zavgS)) compute_zavgS_fencil( setup.edges_size, @@ -118,11 +120,15 @@ def test_compute_zavgS(): S_MYY, offset_provider={"E2V": e2v}, ) - assert_close(-1000788897.3202186, min(zavgS)) - assert_close(1000788897.3202186, max(zavgS)) + if validate: + assert_close(-1000788897.3202186, min(zavgS)) + assert_close(1000788897.3202186, max(zavgS)) -def test_nabla(): +def test_nabla(backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently only supported for cartesian") + backend, validate = backend setup = nabla_setup() sign = np_as_located_field(Vertex, V2E)(setup.sign_field) @@ -145,14 +151,16 @@ def test_nabla(): S_MYY, sign, vol, - backend="double_roundtrip", offset_provider={"E2V": e2v, "V2E": v2e}, + backend=backend, + use_tmps=use_tmps, ) - assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, max(pnabla_MYY)) + if validate: + assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) + assert_close(3.5455427772565435e-003, max(pnabla_MXX)) + assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) + assert_close(3.3540113705465301e-003, max(pnabla_MYY)) @fundef @@ -163,7 +171,7 @@ def sign(node_indices, is_pole_edge): def sign_impl(node_index): def 
impl2(node_indices, is_pole_edge): return if_( - or_(deref(is_pole_edge), node_index == deref(shift(E2V, 0)(node_indices))), + or_(deref(is_pole_edge), eq(node_index, deref(shift(E2V, 0)(node_indices)))), 1.0, -1.0, ) @@ -198,7 +206,11 @@ def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_ ) -def test_nabla_sign(): +def test_nabla_sign(backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently only supported for cartesian") + + backend, validate = backend setup = nabla_setup() # sign = np_as_located_field(Vertex, V2E)(setup.sign_field) @@ -223,13 +235,13 @@ def test_nabla_sign(): vol, index_field(Vertex), is_pole_edge, - # backend="embedded", - backend="double_roundtrip", offset_provider={"E2V": e2v, "V2E": v2e}, - # debug=True, + backend=backend, + use_tmps=use_tmps, ) - assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, max(pnabla_MYY)) + if validate: + assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) + assert_close(3.5455427772565435e-003, max(pnabla_MXX)) + assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) + assert_close(3.3540113705465301e-003, max(pnabla_MYY)) diff --git a/tests/iterator_tests/test_hdiff.py b/tests/iterator_tests/test_hdiff.py index 0d086bcf6f..56688f6195 100644 --- a/tests/iterator_tests/test_hdiff.py +++ b/tests/iterator_tests/test_hdiff.py @@ -1,9 +1,12 @@ +import numpy as np + from iterator.builtins import * -from iterator.runtime import * from iterator.embedded import np_as_located_field -import numpy as np +from iterator.runtime import * + from .hdiff_reference import hdiff_reference + I = offset("I") J = offset("J") @@ -51,11 +54,8 @@ def hdiff(inp, coeff, out, x, y): ) -hdiff(*([None] * 5), backend="lisp") -hdiff(*([None] * 5), backend="cpptoy") - - -def test_hdiff(hdiff_reference): +def test_hdiff(hdiff_reference, backend, use_tmps): + backend, validate = backend inp, coeff, out = hdiff_reference shape = (out.shape[0], out.shape[1]) @@ -63,8 +63,7 @@ def test_hdiff(hdiff_reference): coeff_s = np_as_located_field(IDim, JDim)(coeff[:, :, 0]) out_s = np_as_located_field(IDim, JDim)(np.zeros_like(coeff[:, :, 0])) - # hdiff(inp_s, coeff_s, out_s, shape[0], shape[1]) - # hdiff(inp_s, coeff_s, out_s, shape[0], shape[1], backend="embedded") - hdiff(inp_s, coeff_s, out_s, shape[0], shape[1], backend="double_roundtrip") + hdiff(inp_s, coeff_s, out_s, shape[0], shape[1], backend=backend, use_tmps=use_tmps) - assert np.allclose(out[:, :, 0], out_s) + if validate: + assert np.allclose(out[:, :, 0], out_s) diff --git a/tests/iterator_tests/test_horizontal_indirection.py b/tests/iterator_tests/test_horizontal_indirection.py index c02710b467..ac5496fe57 100644 --- a/tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/iterator_tests/test_horizontal_indirection.py @@ -10,11 +10,13 @@ # ) # ) # ) +import numpy as np from numpy.core.numeric import allclose + +from iterator.builtins import * from iterator.embedded import np_as_located_field from iterator.runtime import * -from iterator.builtins import * -import numpy as np + I = offset("I") diff --git a/tests/iterator_tests/test_popup_tmps.py b/tests/iterator_tests/test_popup_tmps.py new file mode 100644 index 0000000000..012f1b085d --- /dev/null +++ b/tests/iterator_tests/test_popup_tmps.py @@ -0,0 +1,156 @@ +from iterator import ir +from iterator.transforms.popup_tmps import PopupTmps + + +def 
test_trivial_single_lift(): + testee = ir.FunCall( + fun=ir.Lambda( + params=[ir.Sym(id="bar_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ir.Sym(id="foo_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="foo_inp")], + ), + ) + ], + ), + args=[ir.SymRef(id="bar_inp")], + ) + ], + ), + ), + args=[ir.SymRef(id="inp")], + ) + expected = ir.FunCall( + fun=ir.Lambda( + params=[ir.Sym(id="bar_inp"), ir.Sym(id="t0")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="t0")], + ), + ), + args=[ + ir.SymRef(id="inp"), + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ir.Sym(id="foo_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="foo_inp")], + ), + ) + ], + ), + args=[ir.SymRef(id="inp")], + ), + ], + ) + actual = PopupTmps().visit(testee) + assert actual == expected + + +def test_trivial_multiple_lifts(): + testee = ir.FunCall( + fun=ir.Lambda( + params=[ir.SymRef(id="baz_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ir.Sym(id="bar_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ir.Sym(id="foo_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="foo_inp")], + ), + ) + ], + ), + args=[ir.SymRef(id="bar_inp")], + ) + ], + ), + ) + ], + ), + args=[ir.SymRef(id="baz_inp")], + ) + ], + ), + ), + args=[ir.SymRef(id="inp")], + ) + expected = ir.FunCall( + fun=ir.Lambda( + params=[ir.SymRef(id="baz_inp"), ir.SymRef(id="t1")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="t1")], + ), + ), + args=[ + ir.SymRef(id="inp"), + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ + ir.Sym(id="bar_inp"), + ir.Sym(id="t0"), + ], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ + ir.SymRef(id="t0"), + ], + ), + ) + ], + ), + args=[ + ir.SymRef(id="inp"), + ir.FunCall( + fun=ir.FunCall( + fun=ir.SymRef(id="lift"), + args=[ + ir.Lambda( + params=[ir.Sym(id="foo_inp")], + expr=ir.FunCall( + fun=ir.SymRef(id="deref"), + args=[ir.SymRef(id="foo_inp")], + ), + ) + ], + ), + args=[ir.SymRef(id="inp")], + ), + ], + ), + ], + ) + actual = PopupTmps().visit(testee) + assert actual == expected diff --git a/tests/iterator_tests/test_toy_connectivity.py b/tests/iterator_tests/test_toy_connectivity.py index 226eac1eb7..28c47f19fc 100644 --- a/tests/iterator_tests/test_toy_connectivity.py +++ b/tests/iterator_tests/test_toy_connectivity.py @@ -1,13 +1,11 @@ from dataclasses import field + import numpy as np from numpy.core.numeric import allclose -from iterator.runtime import * + from iterator.builtins import * -from iterator.embedded import ( - NeighborTableOffsetProvider, - np_as_located_field, - index_field, -) +from iterator.embedded import NeighborTableOffsetProvider, index_field, np_as_located_field +from iterator.runtime import * Vertex = CartesianAxis("Vertex") @@ -116,14 +114,15 @@ def e2v_sum_fencil(in_edges, out_vertices): ) -def test_sum_edges_to_vertices(): +def test_sum_edges_to_vertices(backend): + backend, validate = backend inp = index_field(Edge) out = np_as_located_field(Vertex)(np.zeros([9])) ref = np.asarray(list(sum(row) for row in v2e_arr)) - e2v_sum_fencil(inp, out, 
backend="double_roundtrip") - assert allclose(out, ref) - e2v_sum_fencil(None, None, backend="cpptoy") + e2v_sum_fencil(inp, out, backend=backend) + if validate: + assert allclose(out, ref) @fundef @@ -141,14 +140,15 @@ def e2v_sum_fencil_reduce(in_edges, out_vertices): ) -def test_sum_edges_to_vertices_reduce(): +def test_sum_edges_to_vertices_reduce(backend): + backend, validate = backend inp = index_field(Edge) out = np_as_located_field(Vertex)(np.zeros([9])) ref = np.asarray(list(sum(row) for row in v2e_arr)) - e2v_sum_fencil_reduce(None, None, backend="cpptoy") - e2v_sum_fencil_reduce(inp, out, backend="double_roundtrip") - assert allclose(out, ref) + e2v_sum_fencil_reduce(inp, out, backend=backend) + if validate: + assert allclose(out, ref) @fundef @@ -195,7 +195,8 @@ def sparse_fencil(inp, out): ) -def test_sparse_input_field(): +def test_sparse_input_field(backend): + backend, validate = backend inp = np_as_located_field(Vertex, V2E)(np.asarray([[1, 2, 3, 4]] * 9)) out = np_as_located_field(Vertex)(np.zeros([9])) @@ -204,17 +205,19 @@ def test_sparse_input_field(): sparse_fencil( inp, out, - backend="double_roundtrip", + backend=backend, offset_provider={"V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, ) - assert allclose(out, ref) + if validate: + assert allclose(out, ref) V2V = offset("V2V") -def test_sparse_input_field_v2v(): +def test_sparse_input_field_v2v(backend): + backend, validate = backend inp = np_as_located_field(Vertex, V2V)(v2v_arr) out = np_as_located_field(Vertex)(np.zeros([9])) @@ -223,11 +226,12 @@ def test_sparse_input_field_v2v(): sparse_fencil( inp, out, - backend="double_roundtrip", + backend=backend, offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, ) - assert allclose(out, ref) + if validate: + assert allclose(out, ref) @fundef @@ -245,7 +249,8 @@ def lift_fencil(inp, out): closure(domain(named_range(Vertex, 0, 9)), lift_stencil, [out], [inp]) -def test_lift(): +def test_lift(backend): + backend, validate = backend inp = index_field(Vertex) out = np_as_located_field(Vertex)(np.zeros([9])) ref = np.asarray(np.asarray(range(9))) @@ -254,10 +259,11 @@ def test_lift(): lift_fencil( inp, out, - backend="double_roundtrip", + backend=backend, offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, ) - assert allclose(out, ref) + if validate: + assert allclose(out, ref) @fundef @@ -275,7 +281,8 @@ def sparse_shifted_fencil(inp, out): ) -def test_shift_sparse_input_field(): +def test_shift_sparse_input_field(backend): + backend, validate = backend inp = np_as_located_field(Vertex, V2V)(v2v_arr) out = np_as_located_field(Vertex)(np.zeros([9])) ref = np.asarray(np.asarray(range(9))) @@ -283,11 +290,12 @@ def test_shift_sparse_input_field(): sparse_shifted_fencil( inp, out, - backend="double_roundtrip", + backend=backend, offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, ) - assert allclose(out, ref) + if validate: + assert allclose(out, ref) @fundef @@ -316,7 +324,8 @@ def sparse_shifted_fencil2(inp_sparse, inp, out1, out2): ) -def test_shift_sparse_input_field2(): +def test_shift_sparse_input_field2(backend): + backend, validate = backend inp = index_field(Vertex) inp_sparse = np_as_located_field(Edge, E2V)(e2v_arr) out1 = np_as_located_field(Vertex)(np.zeros([9])) @@ -327,14 +336,15 @@ def test_shift_sparse_input_field2(): inp, out1, out2, - backend="double_roundtrip", + backend=backend, offset_provider={ "E2V": NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 
2), "V2E": NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), }, ) - assert allclose(out1, out2) + if validate: + assert allclose(out1, out2) @fundef @@ -356,7 +366,8 @@ def sparse_shifted_fencil_reduce(inp, out): ) -def test_shift_sparse_input_field(): +def test_shift_sparse_input_field(backend): + backend, validate = backend inp = np_as_located_field(Vertex, V2V)(v2v_arr) out = np_as_located_field(Vertex)(np.zeros([9])) @@ -372,8 +383,9 @@ def test_shift_sparse_input_field(): sparse_shifted_fencil_reduce( inp, out, - backend="double_roundtrip", + backend=backend, offset_provider={"V2V": NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, ) - assert allclose(np.asarray(out), ref) + if validate: + assert allclose(np.asarray(out), ref) diff --git a/tests/iterator_tests/test_trivial.py b/tests/iterator_tests/test_trivial.py new file mode 100644 index 0000000000..a31c110e19 --- /dev/null +++ b/tests/iterator_tests/test_trivial.py @@ -0,0 +1,53 @@ +import numpy as np + +from iterator.builtins import * +from iterator.embedded import np_as_located_field +from iterator.runtime import * + + +I = offset("I") +J = offset("J") + +IDim = CartesianAxis("IDim") +JDim = CartesianAxis("JDim") + + +@fundef +def foo(foo_inp): + return deref(foo_inp) + + +@fundef +def bar(bar_inp): + return deref(lift(foo)(bar_inp)) + + +@fundef +def baz(baz_inp): + return deref(lift(bar)(baz_inp)) + + +@fendef(offset_provider={"I": IDim, "J": JDim}) +def foobarbaz(inp, out, x, y): + closure( + domain(named_range(IDim, 0, x), named_range(JDim, 0, y)), + baz, + [out], + [inp], + ) + + +def test_trivial(backend, use_tmps): + backend, validate = backend + rng = np.random.default_rng() + inp = rng.uniform(size=(5, 7, 9)) + out = np.copy(inp) + shape = (out.shape[0], out.shape[1]) + + inp_s = np_as_located_field(IDim, JDim, origin={IDim: 0, JDim: 0})(inp[:, :, 0]) + out_s = np_as_located_field(IDim, JDim)(np.zeros_like(inp[:, :, 0])) + + foobarbaz(inp_s, out_s, inp.shape[0], inp.shape[1], backend=backend, use_tmps=use_tmps) + + if validate: + assert np.allclose(out[:, :, 0], out_s) diff --git a/tests/iterator_tests/test_vertical_advection.py b/tests/iterator_tests/test_vertical_advection.py index 0f649f278f..76349bd1cf 100644 --- a/tests/iterator_tests/test_vertical_advection.py +++ b/tests/iterator_tests/test_vertical_advection.py @@ -1,8 +1,9 @@ +import numpy as np +import pytest + from iterator.builtins import * from iterator.embedded import np_as_located_field from iterator.runtime import * -import numpy as np -import pytest @fundef @@ -96,7 +97,10 @@ def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x): ) -def test_tridiag(tridiag_reference): +def test_tridiag(tridiag_reference, backend, use_tmps): + if use_tmps: + pytest.xfail("use_tmps currently not supported for scans") + backend, validate = backend a, b, c, d, x = tridiag_reference shape = a.shape as_3d_field = np_as_located_field(IDim, JDim, KDim) @@ -117,8 +121,9 @@ def test_tridiag(tridiag_reference): x_s, offset_provider={}, column_axis=KDim, - backend="double_roundtrip", - # debug=True, + backend=backend, + use_tmps=use_tmps, ) - assert np.allclose(x, np.asarray(x_s)) + if validate: + assert np.allclose(x, np.asarray(x_s))
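
A minimal sketch of how the `use_tmps` flag introduced in this last patch is driven from user code, assuming the usage pattern of `tests/iterator_tests/test_trivial.py` and `tests/iterator_tests/test_column_stencil.py` above; the stencil and fencil names here (`shifted`, `shifted_copy`, `copy_fencil`) are illustrative and not part of the patches. Every backend now routes the traced IR through `apply_common_transforms`; passing `use_tmps=True` keeps lifted stencil calls from being inlined and instead pops them up into fencil-level temporaries (`PopupTmps` + `CreateGlobalTmps`), which the embedded backend allocates in its generated `*_wrapper`.

import numpy as np

from iterator.builtins import *
from iterator.embedded import np_as_located_field
from iterator.runtime import *


I = offset("I")
IDim = CartesianAxis("IDim")


@fundef
def shifted(inp):
    return deref(shift(I, 1)(inp))


@fundef
def shifted_copy(inp):
    # deref of a lifted call: with use_tmps=True this lift is extracted into a
    # fencil-level temporary instead of being inlined by InlineLifts
    return deref(lift(shifted)(inp))


@fendef(offset_provider={"I": IDim})
def copy_fencil(size, inp, out):
    closure(domain(named_range(IDim, 0, size)), shifted_copy, [out], [inp])


inp = np_as_located_field(IDim)(np.arange(6.0))
out = np_as_located_field(IDim)(np.zeros(5))

copy_fencil(5, inp, out, backend="embedded", use_tmps=True)
assert np.allclose(out, np.arange(1.0, 6.0))

With `use_tmps=False` the same program runs with the lift inlined (`InlineLifts`); with `use_tmps=True` the domain of the temporary is derived by `CreateGlobalTmps` from the shifts collected on the consuming stencil, which is why this path currently requires a purely Cartesian `offset_provider` and is xfailed for scans and unstructured offsets in the tests above.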