From 42523a57ad603eb429458ed9efbbb02098182aa3 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 9 Jan 2026 20:02:45 -0800
Subject: [PATCH 1/5] prod

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index b35977a81c..e25fb04ae1 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -7331,7 +7331,7 @@ def aten_normal(
 
 @torch_op("aten::normal.float_float", trace_only=True)
 def aten_normal_float_float(
-    mean: float, std: float, size: INT64, dtype: int = FLOAT.dtype
+    mean: float, std: float, size: INT64, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", pin_memory: bool = False
 ) -> TensorType:
     """normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
@@ -7655,7 +7655,7 @@ def aten_prod(self: TReal, dtype: int = -1) -> TReal:
 
     if dtype != -1 and dtype is not None:
         self = op.Cast(self, to=dtype)
-    return op.ReduceProd(self)
+    return op.ReduceProd(self, keepdims=False)
 
 
 @torch_op("aten::prod.dim_int", trace_only=True)

From 080ef163622ea33a9b9f98c40ac3eeac23dc2202 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 9 Jan 2026 20:03:41 -0800
Subject: [PATCH 2/5] lint

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 103a0de04c..b169a7a4ee 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -7350,7 +7350,13 @@ def aten_normal(
 
 @torch_op("aten::normal.float_float", trace_only=True)
 def aten_normal_float_float(
-    mean: float, std: float, size: INT64, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", pin_memory: bool = False
+    mean: float,
+    std: float,
+    size: INT64,
+    dtype: int = FLOAT.dtype,
+    layout: str = "",
+    device: str = "",
+    pin_memory: bool = False,
 ) -> TensorType:
     """normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 

From cd21ac2f341581bacfe2dfeadbe0543d24143f7f Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 9 Jan 2026 20:21:42 -0800
Subject: [PATCH 3/5] Prod test

Signed-off-by: Justin Chu
---
 tests/function_libs/torch_lib/e2e_ops_tests.py | 14 ++++++++++++++
 tests/function_libs/torch_lib/ops_test_data.py |  6 ------
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index d344723408..ce1b7d42b0 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -871,6 +871,20 @@ def forward(self, x, index, update):
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_prod(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.prod(x)
+
+        x = torch.randn(3, 4, 5)
+        onnx_program = torch.onnx.export(
+            Model(),
+            (x,),
+            dynamo=True,
+            verbose=False,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 858c8dfe72..6fd852519c 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1064,12 +1064,6 @@ def _where_input_wrangler(
     TorchLibOpInfo("permute", core_ops.aten_permute_complex, complex=True),
     TorchLibOpInfo("polar", core_ops.aten_polar),
     TorchLibOpInfo("pow", core_ops.aten_pow),
-    TorchLibOpInfo("prod", core_ops.aten_prod).skip(
-        matcher=lambda sample: sample.kwargs.get("dim") is not None
-        or sample.kwargs.get("keepdim") is not None
-        or sample.kwargs.get("dtype") != -1,
-        reason="this Aten overload only accept 1 inputs: self",
-    ),
     TorchLibOpInfo("prod_dim_int", core_ops.aten_prod_dim_int).skip(
         matcher=lambda sample: (
             sample.kwargs.get("dim") is None and sample.kwargs.get("keepdim") is None

From c61b6dbafa6d5705403e5aefcfef1a5c78b247a0 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 9 Jan 2026 20:28:19 -0800
Subject: [PATCH 4/5] test

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py |  3 +++
 tests/function_libs/torch_lib/e2e_ops_tests.py | 14 --------------
 tests/function_libs/torch_lib/ops_test_data.py |  6 ++++++
 3 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index b169a7a4ee..0c7fc8acfe 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -7680,6 +7680,9 @@ def aten_prod(self: TReal, dtype: int = -1) -> TReal:
 
     if dtype != -1 and dtype is not None:
         self = op.Cast(self, to=dtype)
+    else:
+        if self.dtype.is_integer():
+            self = op.Cast(self, to=INT64.dtype)
     return op.ReduceProd(self, keepdims=False)
 
 
diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index ce1b7d42b0..d344723408 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -871,20 +871,6 @@ def forward(self, x, index, update):
         )
         _testing.assert_onnx_program(onnx_program)
 
-    def test_prod(self):
-        class Model(torch.nn.Module):
-            def forward(self, x):
-                return torch.prod(x)
-
-        x = torch.randn(3, 4, 5)
-        onnx_program = torch.onnx.export(
-            Model(),
-            (x,),
-            dynamo=True,
-            verbose=False,
-        )
-        _testing.assert_onnx_program(onnx_program)
-
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 6fd852519c..4160d13dab 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1064,6 +1064,12 @@ def _where_input_wrangler(
     TorchLibOpInfo("permute", core_ops.aten_permute_complex, complex=True),
     TorchLibOpInfo("polar", core_ops.aten_polar),
     TorchLibOpInfo("pow", core_ops.aten_pow),
+    TorchLibOpInfo("prod", core_ops.aten_prod).skip(
+        matcher=lambda sample: sample.kwargs.get("dim") is not None
+        or sample.kwargs.get("keepdim") is not None
+        or len(sample.args) > 0,
+        reason="this Aten overload only accept 1 inputs: self",
+    ),
     TorchLibOpInfo("prod_dim_int", core_ops.aten_prod_dim_int).skip(
         matcher=lambda sample: (
             sample.kwargs.get("dim") is None and sample.kwargs.get("keepdim") is None

From a4747c82d04ed17263408d1be87a35b58a35c61a Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 9 Jan 2026 20:28:34 -0800
Subject: [PATCH 5/5] int

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/core.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 0c7fc8acfe..860b878edb 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -7680,9 +7680,8 @@ def aten_prod(self: TReal, dtype: int = -1) -> TReal:
 
     if dtype != -1 and dtype is not None:
         self = op.Cast(self, to=dtype)
-    else:
-        if self.dtype.is_integer():
-            self = op.Cast(self, to=INT64.dtype)
+    elif self.dtype.is_integer():
+        self = op.Cast(self, to=INT64.dtype)
     return op.ReduceProd(self, keepdims=False)
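
Not part of the patch series: a minimal eager-mode sketch, assuming only that torch is installed, of the two PyTorch behaviors the aten_prod changes mirror. torch.prod with no dim reduces over all elements to a 0-d tensor (hence keepdims=False on ReduceProd), and integer inputs are promoted to int64 when no dtype is given (hence the Cast to INT64).

# Hypothetical check, separate from the patches above.
import torch

x = torch.randn(3, 4, 5)
assert torch.prod(x).shape == torch.Size([])  # full reduction yields a 0-d tensor

ints = torch.arange(1, 5, dtype=torch.int32)
assert torch.prod(ints).dtype == torch.int64  # integral inputs promote to int64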