diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8407a79d18..860b878edb 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7350,7 +7350,13 @@ def aten_normal( @torch_op("aten::normal.float_float", trace_only=True) def aten_normal_float_float( - mean: float, std: float, size: INT64, dtype: int = FLOAT.dtype + mean: float, + std: float, + size: INT64, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TensorType: """normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" @@ -7674,7 +7680,9 @@ def aten_prod(self: TReal, dtype: int = -1) -> TReal: if dtype != -1 and dtype is not None: self = op.Cast(self, to=dtype) - return op.ReduceProd(self) + elif self.dtype.is_integer(): + self = op.Cast(self, to=INT64.dtype) + return op.ReduceProd(self, keepdims=False) @torch_op("aten::prod.dim_int", trace_only=True) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 858c8dfe72..4160d13dab 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1067,7 +1067,7 @@ def _where_input_wrangler( TorchLibOpInfo("prod", core_ops.aten_prod).skip( matcher=lambda sample: sample.kwargs.get("dim") is not None or sample.kwargs.get("keepdim") is not None - or sample.kwargs.get("dtype") != -1, + or len(sample.args) > 0, reason="this Aten overload only accept 1 inputs: self", ), TorchLibOpInfo("prod_dim_int", core_ops.aten_prod_dim_int).skip(