From a2d3aba02c8aeff1de6b2fa07cfa25465974f2ab Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 31 Dec 2025 20:53:41 +0000 Subject: [PATCH 1/8] Initial plan From 443ff3a144f2c7a060dcf2789621e216d46b6fea Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 31 Dec 2025 21:09:05 +0000 Subject: [PATCH 2/8] Add elem_type inference to split_predict pass - Added inferElemType() helper function to infer missing elem_type from inputs - Modified split_predict to infer elem_type when creating new graph inputs/outputs - This fixes the bug where split_predict generates invalid ONNX models with missing elem_type Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxoptimizer/passes/split.h | 45 +++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/onnxoptimizer/passes/split.h b/onnxoptimizer/passes/split.h index 5a7b97938..b1b94187b 100644 --- a/onnxoptimizer/passes/split.h +++ b/onnxoptimizer/passes/split.h @@ -21,6 +21,27 @@ static constexpr const char* impure_operators[] = { "Scan", }; +// Helper function to infer elem_type for a value if it's not set +static int32_t inferElemType(const Value* v) { + if (v->elemType() != 0) { + return v->elemType(); + } + + // If elem_type is UNDEFINED, try to infer it from the producing node + const Node* producer = v->node(); + + // For many operators, output type matches input type + // Check if any input has a known elem_type + for (const Value* input : producer->inputs()) { + if (input->elemType() != 0) { + return input->elemType(); + } + } + + // Couldn't infer - return UNDEFINED + return 0; +} + static bool is_pure_operator(Node* n) { for (auto x : impure_operators) { if (n->kind() == Symbol(x)) { @@ -127,6 +148,13 @@ static void split_init_and_predict(Graph& graph, bool init, bool predict) { if (v->node()->kind() == kUndefined) { continue; } + // Ensure the value has elem_type set before registering as output + if (v->elemType() == 0) { + int32_t elem_type = inferElemType(v); + if (elem_type != 0) { + v->setElemType(elem_type); + } + } graph.registerOutput(v); } @@ -169,7 +197,22 @@ static void split_init_and_predict(Graph& graph, bool init, bool predict) { if (v->node()->kind() == kUndefined) { v->replaceAllUsesWith(optionalInputDummyNode->outputs()[0]); } else { - Value* newv = graph.addInput()->copyMetadata(v); + Value* newv = graph.addInput(); + // Copy sizes and name first + if (v->has_sizes()) { + newv->setSizes(v->sizes()); + } + if (v->has_unique_name()) { + newv->setUniqueName(v->uniqueName()); + } + // For elem_type, try to infer if not set + int32_t elem_type = inferElemType(v); + if (elem_type != 0) { + newv->setElemType(elem_type); + } + // Note: If elem_type is still UNDEFINED (0), the resulting model + // will be invalid. This indicates the input model lacks proper + // type information for intermediate values. v->replaceAllUsesWith(newv); } } From 510ca220c67509238301ed3ee14877511d73f7ba Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 31 Dec 2025 21:11:03 +0000 Subject: [PATCH 3/8] Add test for split_predict elem_type preservation Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxoptimizer/test/optimizer_test.py | 45 ++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py index 0f7a8b9cc..9a449ba19 100644 --- a/onnxoptimizer/test/optimizer_test.py +++ b/onnxoptimizer/test/optimizer_test.py @@ -4676,6 +4676,51 @@ def test_rewrite_where(self): assert optimized_model.graph.node[0].input == ['A', 'Y', 'X'] assert optimized_model.graph.node[3].input == ['M', 'X', 'Y'] + def test_split_predict_preserves_elem_type(self): + """ + Test that split_predict pass preserves elem_type information + when creating new inputs/outputs. + + This reproduces the bug where intermediate values without explicit + value_info cause split_predict to generate invalid models with + missing elem_type. + """ + # Create a model where intermediate values don't have value_info + # but should have their types inferred from inputs + model = parser.parse_model(""" + < + ir_version: 7, + opset_import:["": 13] + > + agraph (float[2] X) => (float[2] Y) + { + # Pure operation that can go to init net + one = Constant() + added = Add(X, one) + + # Impure operation that must stay in predict net + # This uses 'added', making it a boundary value + random = RandomUniform() + Y = Add(random, added) + } + """) + + # Optimize with split_predict + optimized_model = self._optimized( + model, ['split_predict'], False, compare_result=False, check=True) + + # Verify the optimized model is valid + # The key check is that checker.check_model() passes, + # which is done by _optimized when check=True + + # Additionally verify all inputs have valid elem_type + for input_val in optimized_model.graph.input: + self.assertIsNotNone(input_val.type) + self.assertTrue(input_val.type.HasField('tensor_type')) + # elem_type 0 is UNDEFINED, which is invalid for inputs + self.assertNotEqual(input_val.type.tensor_type.elem_type, 0, + f"Input {input_val.name} has UNDEFINED elem_type") + if __name__ == "__main__": unittest.main() From 7a94e92a88b33e2de355892761778c0abaf5d86b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 31 Dec 2025 21:12:03 +0000 Subject: [PATCH 4/8] Add C++ test for split_predict elem_type preservation Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- tests/test_simple.cc | 46 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/test_simple.cc b/tests/test_simple.cc index 649497d30..6170cce14 100644 --- a/tests/test_simple.cc +++ b/tests/test_simple.cc @@ -5,6 +5,7 @@ #include #include #include +#include TEST(OptimizerTest, NopReshape) { const char* graph_str = R"( @@ -27,3 +28,48 @@ TEST(OptimizerTest, NopReshape) { ASSERT_EQ(optimized_model.graph().node().size(), 1); ASSERT_EQ(optimized_model.graph().node()[0].op_type(), "Identity"); } + +TEST(OptimizerTest, SplitPredictPreservesElemType) { + // Test that split_predict preserves elem_type for intermediate values + // This reproduces the bug where intermediate values without value_info + // cause split_predict to generate invalid models with missing elem_type + const char* graph_str = R"( + < + ir_version: 7, + opset_import: [ "": 13] + > + agraph (float[2] X) => (float[2] Y) + { + # Pure operation that can go to init net + one = Constant() + added = Add(X, one) + + # Impure operation that must stay in predict net + # This uses 'added', making it a boundary value + random = RandomUniform() + Y = Add(random, added) + } + )"; + + onnx::ModelProto model; + const onnx::Status status = onnx::OnnxParser::Parse(model, graph_str); + EXPECT_TRUE(status.IsOK()); + + // Run split_predict optimization + auto optimized_model = onnx::optimization::Optimize(model, {"split_predict"}); + + // Verify the model is valid - this will catch missing elem_type + try { + onnx::checker::check_model(optimized_model); + } catch (const std::exception& e) { + FAIL() << "Optimized model failed validation: " << e.what(); + } + + // Verify all inputs have valid elem_type (not UNDEFINED/0) + for (const auto& input : optimized_model.graph().input()) { + ASSERT_TRUE(input.has_type()) << "Input " << input.name() << " missing type"; + ASSERT_TRUE(input.type().has_tensor_type()) << "Input " << input.name() << " missing tensor_type"; + ASSERT_NE(input.type().tensor_type().elem_type(), 0) + << "Input " << input.name() << " has UNDEFINED elem_type"; + } +} From e5518b1ab0219fb01a5abf5eb55020c406cf22c2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 31 Dec 2025 21:13:26 +0000 Subject: [PATCH 5/8] Address code review feedback for inferElemType function - Add null check for producer node - Improve documentation to explain limitations and supported operators - Clarify that this is a heuristic for common cases Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxoptimizer/passes/split.h | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/onnxoptimizer/passes/split.h b/onnxoptimizer/passes/split.h index b1b94187b..3b2d1b658 100644 --- a/onnxoptimizer/passes/split.h +++ b/onnxoptimizer/passes/split.h @@ -22,13 +22,21 @@ static constexpr const char* impure_operators[] = { }; // Helper function to infer elem_type for a value if it's not set +// This is a simple heuristic that works for many common operators where +// output type matches input type (e.g., Add, Sub, Mul, etc.). +// For operators with different output types (e.g., Shape, Cast), this +// may return an incorrect type, but that would have been present in the +// original model's value_info if available. static int32_t inferElemType(const Value* v) { if (v->elemType() != 0) { return v->elemType(); } - // If elem_type is UNDEFINED, try to infer it from the producing node + // Check if the value has a producing node const Node* producer = v->node(); + if (!producer) { + return 0; + } // For many operators, output type matches input type // Check if any input has a known elem_type From d5e51ffe352cfa37a1c650b7204274176f2c7a01 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 31 Dec 2025 21:15:25 +0000 Subject: [PATCH 6/8] Improve code comments and documentation - Clarify what happens when elem_type inference fails - Use TensorProto.UNDEFINED constant in Python test instead of magic number - Add comment explaining UNDEFINED value in C++ test Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxoptimizer/passes/split.h | 5 +++-- onnxoptimizer/test/optimizer_test.py | 5 +++-- tests/test_simple.cc | 3 ++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxoptimizer/passes/split.h b/onnxoptimizer/passes/split.h index 3b2d1b658..f705086b2 100644 --- a/onnxoptimizer/passes/split.h +++ b/onnxoptimizer/passes/split.h @@ -219,8 +219,9 @@ static void split_init_and_predict(Graph& graph, bool init, bool predict) { newv->setElemType(elem_type); } // Note: If elem_type is still UNDEFINED (0), the resulting model - // will be invalid. This indicates the input model lacks proper - // type information for intermediate values. + // may be invalid. This indicates the input model lacks proper + // type information for intermediate values. The model validation + // will catch this error during onnx.checker.check_model(). v->replaceAllUsesWith(newv); } } diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py index 9a449ba19..e20764292 100644 --- a/onnxoptimizer/test/optimizer_test.py +++ b/onnxoptimizer/test/optimizer_test.py @@ -4717,8 +4717,9 @@ def test_split_predict_preserves_elem_type(self): for input_val in optimized_model.graph.input: self.assertIsNotNone(input_val.type) self.assertTrue(input_val.type.HasField('tensor_type')) - # elem_type 0 is UNDEFINED, which is invalid for inputs - self.assertNotEqual(input_val.type.tensor_type.elem_type, 0, + # elem_type must be set (not TensorProto_DataType_UNDEFINED which is 0) + self.assertNotEqual(input_val.type.tensor_type.elem_type, + TensorProto.UNDEFINED, f"Input {input_val.name} has UNDEFINED elem_type") diff --git a/tests/test_simple.cc b/tests/test_simple.cc index 6170cce14..4edc12519 100644 --- a/tests/test_simple.cc +++ b/tests/test_simple.cc @@ -65,10 +65,11 @@ TEST(OptimizerTest, SplitPredictPreservesElemType) { FAIL() << "Optimized model failed validation: " << e.what(); } - // Verify all inputs have valid elem_type (not UNDEFINED/0) + // Verify all inputs have valid elem_type (not UNDEFINED which is 0) for (const auto& input : optimized_model.graph().input()) { ASSERT_TRUE(input.has_type()) << "Input " << input.name() << " missing type"; ASSERT_TRUE(input.type().has_tensor_type()) << "Input " << input.name() << " missing tensor_type"; + // TensorProto_DataType_UNDEFINED has value 0 ASSERT_NE(input.type().tensor_type().elem_type(), 0) << "Input " << input.name() << " has UNDEFINED elem_type"; } From 3b606e467363575a1ded60934f14d3f2935df6a0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 31 Dec 2025 21:18:33 +0000 Subject: [PATCH 7/8] Replace magic number 0 with TensorProto_DataType_UNDEFINED constant - Use named constant throughout split.h instead of magic number 0 - Update C++ test to use ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED - Improves code readability and maintainability Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxoptimizer/passes/split.h | 16 ++++++++-------- tests/test_simple.cc | 7 ++++--- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/onnxoptimizer/passes/split.h b/onnxoptimizer/passes/split.h index f705086b2..55eb0d391 100644 --- a/onnxoptimizer/passes/split.h +++ b/onnxoptimizer/passes/split.h @@ -28,26 +28,26 @@ static constexpr const char* impure_operators[] = { // may return an incorrect type, but that would have been present in the // original model's value_info if available. static int32_t inferElemType(const Value* v) { - if (v->elemType() != 0) { + if (v->elemType() != TensorProto_DataType_UNDEFINED) { return v->elemType(); } // Check if the value has a producing node const Node* producer = v->node(); if (!producer) { - return 0; + return TensorProto_DataType_UNDEFINED; } // For many operators, output type matches input type // Check if any input has a known elem_type for (const Value* input : producer->inputs()) { - if (input->elemType() != 0) { + if (input->elemType() != TensorProto_DataType_UNDEFINED) { return input->elemType(); } } // Couldn't infer - return UNDEFINED - return 0; + return TensorProto_DataType_UNDEFINED; } static bool is_pure_operator(Node* n) { @@ -157,9 +157,9 @@ static void split_init_and_predict(Graph& graph, bool init, bool predict) { continue; } // Ensure the value has elem_type set before registering as output - if (v->elemType() == 0) { + if (v->elemType() == TensorProto_DataType_UNDEFINED) { int32_t elem_type = inferElemType(v); - if (elem_type != 0) { + if (elem_type != TensorProto_DataType_UNDEFINED) { v->setElemType(elem_type); } } @@ -215,10 +215,10 @@ static void split_init_and_predict(Graph& graph, bool init, bool predict) { } // For elem_type, try to infer if not set int32_t elem_type = inferElemType(v); - if (elem_type != 0) { + if (elem_type != TensorProto_DataType_UNDEFINED) { newv->setElemType(elem_type); } - // Note: If elem_type is still UNDEFINED (0), the resulting model + // Note: If elem_type is still UNDEFINED, the resulting model // may be invalid. This indicates the input model lacks proper // type information for intermediate values. The model validation // will catch this error during onnx.checker.check_model(). diff --git a/tests/test_simple.cc b/tests/test_simple.cc index 4edc12519..7d36552c2 100644 --- a/tests/test_simple.cc +++ b/tests/test_simple.cc @@ -65,12 +65,13 @@ TEST(OptimizerTest, SplitPredictPreservesElemType) { FAIL() << "Optimized model failed validation: " << e.what(); } - // Verify all inputs have valid elem_type (not UNDEFINED which is 0) + // Verify all inputs have valid elem_type (not UNDEFINED) for (const auto& input : optimized_model.graph().input()) { ASSERT_TRUE(input.has_type()) << "Input " << input.name() << " missing type"; ASSERT_TRUE(input.type().has_tensor_type()) << "Input " << input.name() << " missing tensor_type"; - // TensorProto_DataType_UNDEFINED has value 0 - ASSERT_NE(input.type().tensor_type().elem_type(), 0) + // Check elem_type is not UNDEFINED + ASSERT_NE(input.type().tensor_type().elem_type(), + ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) << "Input " << input.name() << " has UNDEFINED elem_type"; } } From 0b34ad15fef7b769f861e7406775eeae74551f00 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 31 Dec 2025 21:20:29 +0000 Subject: [PATCH 8/8] Add null check for Value parameter in inferElemType - Add null pointer check for Value parameter to prevent segfaults - Clarify comment in Python test about TensorProto.UNDEFINED Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxoptimizer/passes/split.h | 4 ++++ onnxoptimizer/test/optimizer_test.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxoptimizer/passes/split.h b/onnxoptimizer/passes/split.h index 55eb0d391..bbcf2bcde 100644 --- a/onnxoptimizer/passes/split.h +++ b/onnxoptimizer/passes/split.h @@ -28,6 +28,10 @@ static constexpr const char* impure_operators[] = { // may return an incorrect type, but that would have been present in the // original model's value_info if available. static int32_t inferElemType(const Value* v) { + if (!v) { + return TensorProto_DataType_UNDEFINED; + } + if (v->elemType() != TensorProto_DataType_UNDEFINED) { return v->elemType(); } diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py index e20764292..d8ca6c058 100644 --- a/onnxoptimizer/test/optimizer_test.py +++ b/onnxoptimizer/test/optimizer_test.py @@ -4717,7 +4717,7 @@ def test_split_predict_preserves_elem_type(self): for input_val in optimized_model.graph.input: self.assertIsNotNone(input_val.type) self.assertTrue(input_val.type.HasField('tensor_type')) - # elem_type must be set (not TensorProto_DataType_UNDEFINED which is 0) + # elem_type must not be UNDEFINED (TensorProto.UNDEFINED = 0) self.assertNotEqual(input_val.type.tensor_type.elem_type, TensorProto.UNDEFINED, f"Input {input_val.name} has UNDEFINED elem_type")