diff --git a/onnxoptimizer/passes/split.h b/onnxoptimizer/passes/split.h
index 5a7b97938..bbcf2bcde 100644
--- a/onnxoptimizer/passes/split.h
+++ b/onnxoptimizer/passes/split.h
@@ -21,6 +21,39 @@ static constexpr const char* impure_operators[] = {
     "Scan",
 };
 
+// Helper function to infer the elem_type of a value when it is not set.
+// This is a simple heuristic that works for the many common operators
+// whose output type matches their input type (e.g., Add, Sub, Mul).
+// For operators whose output type differs (e.g., Shape, Cast), this may
+// return an incorrect type; the correct type would have been taken from
+// the original model's value_info had it been available.
+static int32_t inferElemType(const Value* v) {
+  if (!v) {
+    return TensorProto_DataType_UNDEFINED;
+  }
+
+  if (v->elemType() != TensorProto_DataType_UNDEFINED) {
+    return v->elemType();
+  }
+
+  // Check whether the value has a producing node
+  const Node* producer = v->node();
+  if (!producer) {
+    return TensorProto_DataType_UNDEFINED;
+  }
+
+  // For many operators the output type matches the input type, so
+  // return the type of the first input with a known elem_type.
+  for (const Value* input : producer->inputs()) {
+    if (input->elemType() != TensorProto_DataType_UNDEFINED) {
+      return input->elemType();
+    }
+  }
+
+  // Could not infer anything - return UNDEFINED
+  return TensorProto_DataType_UNDEFINED;
+}
+
 static bool is_pure_operator(Node* n) {
   for (auto x : impure_operators) {
     if (n->kind() == Symbol(x)) {
@@ -127,6 +160,13 @@ static void split_init_and_predict(Graph& graph, bool init, bool predict) {
       if (v->node()->kind() == kUndefined) {
         continue;
       }
+      // Ensure the value has an elem_type before registering it as an output
+      if (v->elemType() == TensorProto_DataType_UNDEFINED) {
+        int32_t elem_type = inferElemType(v);
+        if (elem_type != TensorProto_DataType_UNDEFINED) {
+          v->setElemType(elem_type);
+        }
+      }
       graph.registerOutput(v);
     }
 
@@ -169,7 +209,23 @@ static void split_init_and_predict(Graph& graph, bool init, bool predict) {
       if (v->node()->kind() == kUndefined) {
        v->replaceAllUsesWith(optionalInputDummyNode->outputs()[0]);
       } else {
-        Value* newv = graph.addInput()->copyMetadata(v);
+        Value* newv = graph.addInput();
+        // Copy the sizes and name first
+        if (v->has_sizes()) {
+          newv->setSizes(v->sizes());
+        }
+        if (v->has_unique_name()) {
+          newv->setUniqueName(v->uniqueName());
+        }
+        // For the elem_type, try to infer it when it is not set
+        int32_t elem_type = inferElemType(v);
+        if (elem_type != TensorProto_DataType_UNDEFINED) {
+          newv->setElemType(elem_type);
+        }
+        // Note: if elem_type is still UNDEFINED, the resulting model may be
+        // invalid. This indicates that the input model lacks proper type
+        // information for intermediate values; onnx.checker.check_model()
+        // will catch the error during model validation.
         v->replaceAllUsesWith(newv);
       }
     }
diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py
index 0f7a8b9cc..d8ca6c058 100644
--- a/onnxoptimizer/test/optimizer_test.py
+++ b/onnxoptimizer/test/optimizer_test.py
@@ -4676,6 +4676,52 @@ def test_rewrite_where(self):
         assert optimized_model.graph.node[0].input == ['A', 'Y', 'X']
         assert optimized_model.graph.node[3].input == ['M', 'X', 'Y']
 
+    def test_split_predict_preserves_elem_type(self):
+        """
+        Test that the split_predict pass preserves elem_type information
+        when creating new inputs/outputs.
+
+        This reproduces the bug where intermediate values without explicit
+        value_info cause split_predict to generate invalid models with
+        missing elem_type.
+ """ + # Create a model where intermediate values don't have value_info + # but should have their types inferred from inputs + model = parser.parse_model(""" + < + ir_version: 7, + opset_import:["": 13] + > + agraph (float[2] X) => (float[2] Y) + { + # Pure operation that can go to init net + one = Constant() + added = Add(X, one) + + # Impure operation that must stay in predict net + # This uses 'added', making it a boundary value + random = RandomUniform() + Y = Add(random, added) + } + """) + + # Optimize with split_predict + optimized_model = self._optimized( + model, ['split_predict'], False, compare_result=False, check=True) + + # Verify the optimized model is valid + # The key check is that checker.check_model() passes, + # which is done by _optimized when check=True + + # Additionally verify all inputs have valid elem_type + for input_val in optimized_model.graph.input: + self.assertIsNotNone(input_val.type) + self.assertTrue(input_val.type.HasField('tensor_type')) + # elem_type must not be UNDEFINED (TensorProto.UNDEFINED = 0) + self.assertNotEqual(input_val.type.tensor_type.elem_type, + TensorProto.UNDEFINED, + f"Input {input_val.name} has UNDEFINED elem_type") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_simple.cc b/tests/test_simple.cc index 649497d30..7d36552c2 100644 --- a/tests/test_simple.cc +++ b/tests/test_simple.cc @@ -5,6 +5,7 @@ #include #include #include +#include TEST(OptimizerTest, NopReshape) { const char* graph_str = R"( @@ -27,3 +28,50 @@ TEST(OptimizerTest, NopReshape) { ASSERT_EQ(optimized_model.graph().node().size(), 1); ASSERT_EQ(optimized_model.graph().node()[0].op_type(), "Identity"); } + +TEST(OptimizerTest, SplitPredictPreservesElemType) { + // Test that split_predict preserves elem_type for intermediate values + // This reproduces the bug where intermediate values without value_info + // cause split_predict to generate invalid models with missing elem_type + const char* graph_str = R"( + < + ir_version: 7, + opset_import: [ "": 13] + > + agraph (float[2] X) => (float[2] Y) + { + # Pure operation that can go to init net + one = Constant() + added = Add(X, one) + + # Impure operation that must stay in predict net + # This uses 'added', making it a boundary value + random = RandomUniform() + Y = Add(random, added) + } + )"; + + onnx::ModelProto model; + const onnx::Status status = onnx::OnnxParser::Parse(model, graph_str); + EXPECT_TRUE(status.IsOK()); + + // Run split_predict optimization + auto optimized_model = onnx::optimization::Optimize(model, {"split_predict"}); + + // Verify the model is valid - this will catch missing elem_type + try { + onnx::checker::check_model(optimized_model); + } catch (const std::exception& e) { + FAIL() << "Optimized model failed validation: " << e.what(); + } + + // Verify all inputs have valid elem_type (not UNDEFINED) + for (const auto& input : optimized_model.graph().input()) { + ASSERT_TRUE(input.has_type()) << "Input " << input.name() << " missing type"; + ASSERT_TRUE(input.type().has_tensor_type()) << "Input " << input.name() << " missing tensor_type"; + // Check elem_type is not UNDEFINED + ASSERT_NE(input.type().tensor_type().elem_type(), + ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) + << "Input " << input.name() << " has UNDEFINED elem_type"; + } +}