diff --git a/dagrt/builtins_python.py b/dagrt/builtins_python.py index 623776c..ab8eca8 100644 --- a/dagrt/builtins_python.py +++ b/dagrt/builtins_python.py @@ -59,6 +59,11 @@ def builtin_dot_product(a, b): return np.vdot(a, b) +def builtin_elementwise_abs(x): + import numpy as np + return np.abs(x) + + def builtin_array(n): import numpy as np if n != np.floor(n): @@ -141,6 +146,7 @@ def builtin_print(arg): "norm_2": builtin_norm_2, "norm_inf": builtin_norm_inf, "dot_product": builtin_dot_product, + "elementwise_abs": builtin_elementwise_abs, "array": builtin_array, "matmul": builtin_matmul, "transpose": builtin_transpose, diff --git a/dagrt/codegen/fortran.py b/dagrt/codegen/fortran.py index 3327855..63a5c32 100644 --- a/dagrt/codegen/fortran.py +++ b/dagrt/codegen/fortran.py @@ -2407,12 +2407,15 @@ def codegen_builtin_len(results, function, args, arg_kinds, code_generator.emit("") -class AbsComputer(TypeVisitorWithResult): - def visit_BuiltinType(self, fortran_type, fortran_expr_str, index_expr_map): +class AbsComputer(AssignmentEmitter): + def visit_BuiltinType(self, fortran_type, fortran_expr_str, index_expr_map, + rhs_expr, is_rhs_target): + expr = self.code_generator.expr(rhs_expr) self.code_generator.emit( - "{result} = abs({result})" + "{result} = abs({expr})" .format( - result=self.result_expr)) + result=fortran_expr_str, + expr=expr)) def codegen_builtin_elementwise_abs(results, function, args, arg_kinds, @@ -2421,25 +2424,23 @@ def codegen_builtin_elementwise_abs(results, function, args, arg_kinds, from dagrt.data import Scalar, Array, UserType x_kind = arg_kinds[0] - if isinstance(x_kind, Scalar): - if x_kind.is_real_valued: - ftype = BuiltinType("real*8") - else: - ftype = BuiltinType("complex*16") - elif isinstance(x_kind, UserType): - ftype = code_generator.user_type_map[x_kind.identifier] - elif isinstance(x_kind, Array): + if isinstance(x_kind, Scalar) or isinstance(x_kind, Array): code_generator.emit("{result} = abs({arg})".format( result=result, arg=args[0])) return + elif isinstance(x_kind, UserType): + ftype = code_generator.user_type_map[x_kind.identifier] else: raise TypeError("unsupported kind for elementwise_abs argument: %s" % x_kind) - code_generator.emit(f"{result} = 0") - code_generator.emit("") - - AbsComputer(code_generator, result)(ftype, args[0], {}) + # Need to pass argument to assignment emitter as a variable (for mappers) + # Call it a target to avoid name mangling + from pymbolic import var + argvar = var("" + args[0]) + code_generator.sym_kind_table.set( + None, "" + args[0], UserType(x_kind.identifier)) + AbsComputer(code_generator)(ftype, result, {}, argvar, is_rhs_target=False) code_generator.emit("") diff --git a/test/test_codegen_fortran.py b/test/test_codegen_fortran.py index 9d3fd3d..a83b89d 100755 --- a/test/test_codegen_fortran.py +++ b/test/test_codegen_fortran.py @@ -150,19 +150,32 @@ def test_self_dep_in_loop(): fortran_libraries=["lapack", "blas"]) +class AbsFailure: + pass + + def test_elementwise_abs(): with CodeBuilder(name="primary") as cb: - cb("y", "f(0, ytype)") - cb("ytype", "y") - # Test new builtin on a usertype. - cb("z", "elementwise_abs(ytype)") cb("i", "array(20)") cb("i[j]", "-j", loops=(("j", 0, 20),)) # Test new builtin on an array type. cb("k", "elementwise_abs(i)") + with cb.if_("k[20] > 19"): + cb.raise_(AbsFailure) + with cb.if_("k[20] < 19"): + cb.raise_(AbsFailure) # Test new builtin on a scalar. cb("l", "elementwise_abs(-20)") + with cb.if_("l > 20"): + cb.raise_(AbsFailure) + with cb.if_("l < 20"): + cb.raise_(AbsFailure) + cb("y", "f(0, ytype)") + cb("ytype", "y") + # Test new builtin on a usertype. + cb("ytype", "elementwise_abs(ytype)") + # (We check this in the outer test code) code = create_DAGCode_with_steady_phase(cb.statements) diff --git a/test/test_element_abs.f90 b/test/test_element_abs.f90 index a5129a4..7a92ec4 100644 --- a/test/test_element_abs.f90 +++ b/test/test_element_abs.f90 @@ -13,6 +13,8 @@ program test_element_abs real*8, dimension(100) :: y0 integer i + integer stderr + parameter(stderr=0) ! start code ---------------------------------------------------------------- @@ -20,11 +22,17 @@ program test_element_abs do i = 1, 100 - y0 = i + y0(i) = i end do call timestep_initialize(dagrt_state=dagrt_state_ptr, state_ytype=y0) call timestep_run(dagrt_state=dagrt_state_ptr) + ! For the UserType, check that the absolute value did its job. + do i = 1, 100 + if (dagrt_state%state_ytype(i) /= 2*y0(i)) then + write(stderr,*) "UserType elementwise abs failure" + endif + end do call timestep_shutdown(dagrt_state=dagrt_state_ptr) end program