diff --git a/pytato/array.py b/pytato/array.py index a30b288e2..91427c0da 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1087,6 +1087,9 @@ def einsum(subscripts: str, *operands: Array) -> Einsum: """ Einstein summation *subscripts* on *operands*. """ + + from pytato.tags import EinsumInfo + if len(operands) == 0: raise ValueError("must specify at least one operand") @@ -1120,6 +1123,7 @@ def einsum(subscripts: str, *operands: Array) -> Einsum: access_descriptors.append(access_descriptor) return Einsum(tuple(access_descriptors), operands, + tags=frozenset([EinsumInfo(subscripts)]), axes=_get_default_axes(len(out_spec))) # }}} diff --git a/pytato/tags.py b/pytato/tags.py index 601498846..ec39fd479 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -97,6 +97,19 @@ class PrefixNamed(_BaseNameTag): # }}} +# {{{ User operation name + +class UserOpInfo(UniqueTag): + pass + + +@tag_dataclass +class EinsumInfo(UserOpInfo): + spec: str + +# }}} + + @tag_dataclass class AssumeNonNegative(Tag): """