From 751cdf0b4445e1672da0ee63cebcc74904c6980a Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 25 Aug 2021 17:34:40 -0500 Subject: [PATCH 1/3] export Reshape --- pytato/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/__init__.py b/pytato/__init__.py index c9c939f3d..c6d6694e4 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -26,7 +26,7 @@ from pytato.array import ( Array, AbstractResultWithNamedArrays, DictOfNamedArrays, Placeholder, - IndexLambda, NamedArray, DataWrapper, InputArgumentBase, + IndexLambda, NamedArray, DataWrapper, InputArgumentBase, Reshape, make_dict_of_named_arrays, make_placeholder, make_size_param, make_data_wrapper, @@ -68,7 +68,7 @@ __all__ = ( "Array", "AbstractResultWithNamedArrays", "DictOfNamedArrays", "Placeholder", "IndexLambda", "NamedArray", "LoopyCall", - "DataWrapper", "InputArgumentBase", + "DataWrapper", "InputArgumentBase", "Reshape", "make_dict_of_named_arrays", "make_placeholder", "make_size_param", "make_data_wrapper", "einsum", From 8e4a79328a064e971c47df5220ef4dff248ca149 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 27 Aug 2021 11:54:29 -0500 Subject: [PATCH 2/3] adds EinsumInfo --- pytato/__init__.py | 3 ++- pytato/array.py | 6 +++++- pytato/tags.py | 13 +++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/pytato/__init__.py b/pytato/__init__.py index c6d6694e4..f58a64d68 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -27,6 +27,7 @@ from pytato.array import ( Array, AbstractResultWithNamedArrays, DictOfNamedArrays, Placeholder, IndexLambda, NamedArray, DataWrapper, InputArgumentBase, Reshape, + Einsum, make_dict_of_named_arrays, make_placeholder, make_size_param, make_data_wrapper, @@ -68,7 +69,7 @@ __all__ = ( "Array", "AbstractResultWithNamedArrays", "DictOfNamedArrays", "Placeholder", "IndexLambda", "NamedArray", "LoopyCall", - "DataWrapper", "InputArgumentBase", "Reshape", + "DataWrapper", "InputArgumentBase", "Reshape", "Einsum", "make_dict_of_named_arrays", "make_placeholder", "make_size_param", "make_data_wrapper", "einsum", diff --git a/pytato/array.py b/pytato/array.py index 718f296a4..ebb194080 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1042,6 +1042,9 @@ def einsum(subscripts: str, *operands: Array) -> Einsum: """ Einstein summation *subscripts* on *operands*. """ + + from pytato.tags import EinsumInfo + if len(operands) == 0: raise ValueError("must specify at least one operand") @@ -1074,7 +1077,8 @@ def einsum(subscripts: str, *operands: Array) -> Einsum: index_to_axis_length)) access_descriptors.append(access_descriptor) - return Einsum(tuple(access_descriptors), operands) + return Einsum(tuple(access_descriptors), operands, + tags=frozenset([EinsumInfo(subscripts)])) # }}} diff --git a/pytato/tags.py b/pytato/tags.py index fa1c134c2..a15e36bca 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -94,3 +94,16 @@ class PrefixNamed(_BaseNameTag): prefix: str # }}} + + +# {{{ User operation name + +class UserOpInfo(UniqueTag): + pass + + +@tag_dataclass +class EinsumInfo(UserOpInfo): + spec: str + +# }}} From b8cc63dc3fc9a1d838a8030cfee638610344ddae Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 7 Dec 2021 13:31:13 -0600 Subject: [PATCH 3/3] fix merge error --- pytato/array.py | 55 +------------------------------------------------ 1 file changed, 1 insertion(+), 54 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 9425e9092..c9ef2c597 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1123,61 +1123,8 @@ def einsum(subscripts: str, *operands: Array) -> Einsum: access_descriptors.append(access_descriptor) return Einsum(tuple(access_descriptors), operands, -<<<<<<< HEAD - tags=frozenset([EinsumInfo(subscripts)])) - -# }}} - - -# {{{ matrix product - -class MatrixProduct(Array): - """A product of two matrices, or a matrix and a vector. - - The semantics of this operation follow PEP 465 [pep465]_, i.e., the Python - matmul (@) operator. - - .. attribute:: x1 - .. attribute:: x2 - - .. [pep465] https://www.python.org/dev/peps/pep-0465/ - - """ - _fields = Array._fields + ("x1", "x2") - - _mapper_method = "map_matrix_product" - - def __init__(self, - x1: Array, - x2: Array, - tags: TagsType = frozenset()): - super().__init__(tags) - self.x1 = x1 - self.x2 = x2 - - @property - def shape(self) -> ShapeType: - # FIXME: Broadcasting currently unsupported. - assert 0 < self.x1.ndim <= 2 - assert 0 < self.x2.ndim <= 2 - - if self.x1.ndim == 1 and self.x2.ndim == 1: - return () - elif self.x1.ndim == 1 and self.x2.ndim == 2: - return (self.x2.shape[1],) - elif self.x1.ndim == 2 and self.x2.ndim == 1: - return (self.x1.shape[0],) - elif self.x1.ndim == 2 and self.x2.ndim == 2: - return (self.x1.shape[0], self.x2.shape[1]) - - raise AssertionError() - - @property - def dtype(self) -> np.dtype[Any]: - return _np_result_type(self.x1.dtype, self.x2.dtype) -======= + tags=frozenset([EinsumInfo(subscripts)]), axes=_get_default_axes(len(out_spec))) ->>>>>>> main # }}}