From e99d2424ab130cb9a8210758b82c6811e26c5648 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 13 Jun 2025 12:14:55 -0500 Subject: [PATCH] Copy-paste old pymbolic.imperative to avoid compat hacks --- dagrt/language.py | 178 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 173 insertions(+), 5 deletions(-) diff --git a/dagrt/language.py b/dagrt/language.py index d007787..49962c4 100644 --- a/dagrt/language.py +++ b/dagrt/language.py @@ -29,9 +29,6 @@ from immutabledict import immutabledict -from pymbolic.imperative.statement import ( - ConditionalAssignment as AssignBase, ConditionalStatement as StatementBase, - Nop as NopBase) from pytools import RecordWithoutPickling, memoize_method, natsorted from dagrt.utils import get_variables @@ -176,7 +173,177 @@ def print_stmt(stmt): # }}} -class Statement(StatementBase): +# {{{ statement classes + +# copy-pasted from pymbolic + +class StatementBase(RecordWithoutPickling): + """ + .. attribute:: depends_on + + A :class:`frozenset` of instruction ids that are reuqired to be + executed within this execution context before this instruction can be + executed. + + .. attribute:: id + + A string, a unique identifier for this instruction. + + .. automethod:: get_written_variables + .. automethod:: get_read_variables + """ + + def __init__(self, **kwargs): + id = kwargs.pop("id", None) + if id is not None: + id = intern(id) + + depends_on = frozenset(kwargs.pop("depends_on", [])) + super().__init__( + id=id, + depends_on=depends_on, + **kwargs) + + def get_written_variables(self): + """Returns a :class:`frozenset` of variables being written by this + instruction. + """ + return frozenset() + + def get_read_variables(self): + """Returns a :class:`frozenset` of variables being read by this + instruction. + """ + return frozenset() + + def map_expressions(self, mapper, include_lhs=True): + """Returns a new copy of *self* with all expressions + replaced by ``mapepr(expr)`` for every + :class:`pymbolic.primitives.Expression` + contained in *self*. + """ + return self + + def get_dependency_mapper(self, include_calls="descend_args"): + from pymbolic.mapper.dependency import DependencyMapper + return DependencyMapper( + include_subscripts=False, + include_lookups=False, + include_calls=include_calls) + +# }}} + + +# {{{ statement with condition + +class ConditionalStatementBase(StatementBase): + __doc__ = StatementBase.__doc__ + """ + .. attribute:: condition + + The instruction condition as a :mod:`pymbolic` expression (`True` if the + instruction is unconditionally executed) + """ + + def __init__(self, **kwargs): + condition = kwargs.pop("condition", True) + super().__init__( + condition=condition, + **kwargs) + + def _condition_printing_suffix(self): + if self.condition is True: + return "" + return " if " + str(self.condition) + + def __str__(self): + return (super().__str__() + + self._condition_printing_suffix()) + + def get_read_variables(self): + dep_mapper = self.get_dependency_mapper() + return ( + super().get_read_variables() + | frozenset( + dep.name for dep in dep_mapper(self.condition))) + +# }}} + + +# {{{ assignment + +class AssignBase(StatementBase): + """ + .. attribute:: lhs + .. attribute:: rhs + """ + + def __init__(self, lhs, rhs, **kwargs): + super().__init__( + lhs=lhs, + rhs=rhs, + **kwargs) + + def get_written_variables(self): + from pymbolic.primitives import Subscript, Variable + if isinstance(self.lhs, Variable): + return frozenset([self.lhs.name]) + elif isinstance(self.lhs, Subscript): + assert isinstance(self.lhs.aggregate, Variable) + return frozenset([self.lhs.aggregate.name]) + else: + raise TypeError("unexpected type of LHS") + + def get_read_variables(self): + result = super().get_read_variables() + get_deps = self.get_dependency_mapper() + + def get_vars(expr): + return frozenset(dep.name for dep in get_deps(self.rhs)) + + result = get_vars(self.rhs) | get_vars(self.lhs) + + return result + + def map_expressions(self, mapper, include_lhs=True): + return (super() + .map_expressions(mapper, include_lhs=include_lhs) + .copy( + lhs=mapper(self.lhs) if include_lhs else self.lhs, + rhs=mapper(self.rhs))) + + def __str__(self): + result = "{assignee} <- {expr}".format( + assignee=str(self.lhs), + expr=str(self.rhs),) + + return result + +# }}} + + +# {{{ conditional assignment + +class ConditionalAssignment(ConditionalStatementBase, AssignBase): + def map_expressions(self, mapper, include_lhs=True): + return (super() + .map_expressions(mapper, include_lhs=include_lhs) + .copy(condition=mapper(self.condition))) + +# }}} + + +# {{{ nop + +class NopBase(StatementBase): + def __str__(self): + return "nop" + +# }}} + + +# end pymbolic copy-pasta + +class Statement(ConditionalStatementBase): def get_dependency_mapper(self, include_calls="descend_args"): from dagrt.expression import ExtendedDependencyMapper return ExtendedDependencyMapper( @@ -243,7 +410,8 @@ def __init__(self, assignee=None, assignee_subscript=None, expression=None, lhs=lhs, rhs=flatten(rhs), loops=loops, - **kwargs) + **kwargs + ) @property def assignee(self):