Skip to content
This repository was archived by the owner on Jun 14, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 173 additions & 5 deletions dagrt/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@

from immutabledict import immutabledict

from pymbolic.imperative.statement import (
ConditionalAssignment as AssignBase, ConditionalStatement as StatementBase,
Nop as NopBase)
from pytools import RecordWithoutPickling, memoize_method, natsorted

from dagrt.utils import get_variables
Expand Down Expand Up @@ -176,7 +173,177 @@ def print_stmt(stmt):
# }}}


class Statement(StatementBase):
# {{{ statement classes

# copy-pasted from pymbolic

class StatementBase(RecordWithoutPickling):
"""
.. attribute:: depends_on

A :class:`frozenset` of instruction ids that are reuqired to be
executed within this execution context before this instruction can be
executed.

.. attribute:: id

A string, a unique identifier for this instruction.

.. automethod:: get_written_variables
.. automethod:: get_read_variables
"""

def __init__(self, **kwargs):
id = kwargs.pop("id", None)
if id is not None:
id = intern(id)

depends_on = frozenset(kwargs.pop("depends_on", []))
super().__init__(
id=id,
depends_on=depends_on,
**kwargs)

def get_written_variables(self):
"""Returns a :class:`frozenset` of variables being written by this
instruction.
"""
return frozenset()

def get_read_variables(self):
"""Returns a :class:`frozenset` of variables being read by this
instruction.
"""
return frozenset()

def map_expressions(self, mapper, include_lhs=True):
"""Returns a new copy of *self* with all expressions
replaced by ``mapepr(expr)`` for every
:class:`pymbolic.primitives.Expression`
contained in *self*.
"""
return self

def get_dependency_mapper(self, include_calls="descend_args"):
from pymbolic.mapper.dependency import DependencyMapper
return DependencyMapper(
include_subscripts=False,
include_lookups=False,
include_calls=include_calls)

# }}}


# {{{ statement with condition

class ConditionalStatementBase(StatementBase):
__doc__ = StatementBase.__doc__ + """
.. attribute:: condition

The instruction condition as a :mod:`pymbolic` expression (`True` if the
instruction is unconditionally executed)
"""

def __init__(self, **kwargs):
condition = kwargs.pop("condition", True)
super().__init__(
condition=condition,
**kwargs)

def _condition_printing_suffix(self):
if self.condition is True:
return ""
return " if " + str(self.condition)

def __str__(self):
return (super().__str__()
+ self._condition_printing_suffix())

def get_read_variables(self):
dep_mapper = self.get_dependency_mapper()
return (
super().get_read_variables()
| frozenset(
dep.name for dep in dep_mapper(self.condition)))

# }}}


# {{{ assignment

class AssignBase(StatementBase):
"""
.. attribute:: lhs
.. attribute:: rhs
"""

def __init__(self, lhs, rhs, **kwargs):
super().__init__(
lhs=lhs,
rhs=rhs,
**kwargs)

def get_written_variables(self):
from pymbolic.primitives import Subscript, Variable
if isinstance(self.lhs, Variable):
return frozenset([self.lhs.name])
elif isinstance(self.lhs, Subscript):
assert isinstance(self.lhs.aggregate, Variable)
return frozenset([self.lhs.aggregate.name])
else:
raise TypeError("unexpected type of LHS")

def get_read_variables(self):
result = super().get_read_variables()
get_deps = self.get_dependency_mapper()

def get_vars(expr):
return frozenset(dep.name for dep in get_deps(self.rhs))

result = get_vars(self.rhs) | get_vars(self.lhs)

return result

def map_expressions(self, mapper, include_lhs=True):
return (super()
.map_expressions(mapper, include_lhs=include_lhs)
.copy(
lhs=mapper(self.lhs) if include_lhs else self.lhs,
rhs=mapper(self.rhs)))

def __str__(self):
result = "{assignee} <- {expr}".format(
assignee=str(self.lhs),
expr=str(self.rhs),)

return result

# }}}


# {{{ conditional assignment

class ConditionalAssignment(ConditionalStatementBase, AssignBase):
def map_expressions(self, mapper, include_lhs=True):
return (super()
.map_expressions(mapper, include_lhs=include_lhs)
.copy(condition=mapper(self.condition)))

# }}}


# {{{ nop

class NopBase(StatementBase):
def __str__(self):
return "nop"

# }}}


# end pymbolic copy-pasta

class Statement(ConditionalStatementBase):
def get_dependency_mapper(self, include_calls="descend_args"):
from dagrt.expression import ExtendedDependencyMapper
return ExtendedDependencyMapper(
Expand Down Expand Up @@ -243,7 +410,8 @@ def __init__(self, assignee=None, assignee_subscript=None, expression=None,
lhs=lhs,
rhs=flatten(rhs),
loops=loops,
**kwargs)
**kwargs
)

@property
def assignee(self):
Expand Down
Loading