diff --git a/examples/advection.py b/examples/advection.py index 0a4e17890..13a9e9dc9 100755 --- a/examples/advection.py +++ b/examples/advection.py @@ -190,6 +190,7 @@ def main(): u = pt.make_placeholder(name="u", shape=(dg_ops.nelements, dg_ops.nnodes), dtype=np.float64) result = op.apply(u) + result = pt.transform.deduplicate(result) prog = pt.generate_loopy(result, cl_device=queue.device) print(prog.program) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 6cbc70620..1213d49af 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -358,16 +358,16 @@ def map_concatenate(self, expr: Concatenate) -> list[ArrayOrNames]: def map_einsum(self, expr: Einsum) -> list[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) + list(expr.args) + def map_loopy_call(self, expr: LoopyCall) -> list[ArrayOrNames]: + return [ary for ary in expr.bindings.values() if isinstance(ary, Array)] + def map_loopy_call_result(self, expr: NamedArray) -> list[ArrayOrNames]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return ( - self._get_preds_from_shape(expr.shape) - + [ - ary - for ary in expr._container.bindings.values() - if isinstance(ary, Array)]) + return [ + *self._get_preds_from_shape(expr.shape), + expr._container] def _map_index_base(self, expr: IndexBase) -> list[ArrayOrNames]: return ( diff --git a/pytato/loopy.py b/pytato/loopy.py index 9fdb1c338..dbe018d14 100644 --- a/pytato/loopy.py +++ b/pytato/loopy.py @@ -96,7 +96,7 @@ """ -@array_dataclass() +@array_dataclass(hash=False) class LoopyCall(AbstractResultWithNamedArrays): """ An array expression node representing a call to an entrypoint in a diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 75b7a60f1..aa3bd2838 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -435,7 +435,7 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """ def __init__( self, - err_on_collision: bool = False, + err_on_collision: bool = __debug__, _cache: CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: @@ -661,8 +661,8 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): """ def __init__( self, - err_on_collision: bool = False, - err_on_created_duplicate: bool = False, + err_on_collision: bool = __debug__, + err_on_created_duplicate: bool = __debug__, _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: @@ -756,8 +756,8 @@ class TransformMapperWithExtraArgs( """ def __init__( self, - err_on_collision: bool = False, - err_on_created_duplicate: bool = False, + err_on_collision: bool = __debug__, + err_on_created_duplicate: bool = __debug__, _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, P] | None = None @@ -1931,8 +1931,8 @@ def __init__( self, nsuccessors: Mapping[Array, int], _cache: MPMSMaterializerCache | None = None): - err_on_collision = False - err_on_created_duplicate = False + err_on_collision = __debug__ + err_on_created_duplicate = __debug__ if _cache is None: _cache = MPMSMaterializerCache( diff --git a/test/test_codegen.py b/test/test_codegen.py index 49bc41bb0..23a5b0f6e 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -2137,7 +2137,9 @@ def test_zeros_like(ctx_factory): assert isinstance(zero, pt.Array) assert isinstance(one, pt.Array) - prg = pt.generate_loopy({"zero": zero, "one": one}) + prg = pt.generate_loopy( + pt.transform.deduplicate( + pt.make_dict_of_named_arrays({"zero": zero, "one": one}))) _, pt_out = prg(cq, x=x_in) np.testing.assert_allclose(pt_out["zero"], 0) np.testing.assert_allclose(pt_out["one"], 1)