58 changes: 57 additions & 1 deletion onnxoptimizer/passes/split.h
@@ -21,6 +21,39 @@ static constexpr const char* impure_operators[] = {
"Scan",
};

// Helper function to infer elem_type for a value when it is not set.
// This is a simple heuristic that works for the many common operators whose
// output type matches their input type (e.g., Add, Sub, Mul).
// For operators whose output type differs from the input type (e.g., Shape,
// Cast), it may return an incorrect type; however, if the original model
// carried value_info for such a value, its elem_type would already be set
// and this fallback would never run. Note that the heuristic inspects only
// the immediate producer's inputs and does not recurse.
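// Example: given `added = Add(X, one)` where X is float[2] and `added`
// carries no value_info, the heuristic walks to the producing Add node and
// returns FLOAT; given `s = Shape(X)` it would return FLOAT rather than the
// correct INT64, which is the limitation noted above.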
static int32_t inferElemType(const Value* v) {
if (!v) {
return TensorProto_DataType_UNDEFINED;
}

if (v->elemType() != TensorProto_DataType_UNDEFINED) {
return v->elemType();
}

// Check if the value has a producing node
const Node* producer = v->node();
if (!producer) {
return TensorProto_DataType_UNDEFINED;
}

// For many operators the output type matches the input type,
// so fall back to the first input with a known elem_type
for (const Value* input : producer->inputs()) {
if (input->elemType() != TensorProto_DataType_UNDEFINED) {
return input->elemType();
}
}

// Couldn't infer - return UNDEFINED
return TensorProto_DataType_UNDEFINED;
}

static bool is_pure_operator(Node* n) {
for (auto x : impure_operators) {
if (n->kind() == Symbol(x)) {
@@ -127,6 +160,13 @@ static void split_init_and_predict(Graph& graph, bool init, bool predict) {
if (v->node()->kind() == kUndefined) {
continue;
}
// Ensure the value has elem_type set before registering as output
if (v->elemType() == TensorProto_DataType_UNDEFINED) {
int32_t elem_type = inferElemType(v);
if (elem_type != TensorProto_DataType_UNDEFINED) {
v->setElemType(elem_type);
}
}
graph.registerOutput(v);
}

@@ -169,7 +209,23 @@ static void split_init_and_predict(Graph& graph, bool init, bool predict) {
if (v->node()->kind() == kUndefined) {
v->replaceAllUsesWith(optionalInputDummyNode->outputs()[0]);
} else {
Value* newv = graph.addInput()->copyMetadata(v);
Value* newv = graph.addInput();
// Copy metadata field by field (instead of using copyMetadata) so an
// inferred elem_type can be filled in when v's is UNDEFINED
if (v->has_sizes()) {
newv->setSizes(v->sizes());
}
if (v->has_unique_name()) {
newv->setUniqueName(v->uniqueName());
}
// For elem_type, try to infer it when it is not already set
int32_t elem_type = inferElemType(v);
if (elem_type != TensorProto_DataType_UNDEFINED) {
newv->setElemType(elem_type);
}
// Note: If elem_type is still UNDEFINED at this point, the resulting
// model may be invalid. That means the input model lacks type
// information for this intermediate value; onnx.checker.check_model()
// will report the missing elem_type during validation.
v->replaceAllUsesWith(newv);
}
}
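Taken together, the changes above let split_predict handle models whose
boundary values lack value_info. A minimal end-to-end reproduction of the
scenario, as a sketch (assuming the onnxoptimizer Python package and its
optimize() entry point; the model text mirrors the tests added below):

import onnx
import onnxoptimizer
from onnx import parser

model = parser.parse_model("""
<ir_version: 7, opset_import: ["": 13]>
agraph (float[2] X) => (float[2] Y)
{
    one = Constant<value: tensor = float[2] {1.0, 1.0}>()
    added = Add(X, one)
    random = RandomUniform<dtype: int = 1, shape: ints = [2]>()
    Y = Add(random, added)
}
""")

optimized = onnxoptimizer.optimize(model, ["split_predict"])
# Before this fix, the boundary value 'added' became a predict-net input
# with elem_type left UNDEFINED and the checker rejected the model.
onnx.checker.check_model(optimized)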
46 changes: 46 additions & 0 deletions onnxoptimizer/test/optimizer_test.py
@@ -4676,6 +4676,52 @@ def test_rewrite_where(self):
assert optimized_model.graph.node[0].input == ['A', 'Y', 'X']
assert optimized_model.graph.node[3].input == ['M', 'X', 'Y']

def test_split_predict_preserves_elem_type(self):
"""
Test that the split_predict pass preserves elem_type information
when creating new inputs and outputs.

This reproduces the bug where intermediate values without explicit
value_info cause split_predict to generate invalid models with
missing elem_type.
"""
# Create a model where intermediate values don't have value_info
# but should have their types inferred from inputs
model = parser.parse_model("""
<
ir_version: 7,
opset_import:["": 13]
>
agraph (float[2] X) => (float[2] Y)
{
# Pure operation that can go to init net
one = Constant<value: tensor = float[2] {1.0, 1.0}>()
added = Add(X, one)

# Impure operation that must stay in predict net
# This uses 'added', making it a boundary value
random = RandomUniform<dtype: int = 1, shape: ints = [2]>()
Y = Add(random, added)
}
""")

# Optimize with split_predict
optimized_model = self._optimized(
model, ['split_predict'], False, compare_result=False, check=True)

# Verify the optimized model is valid
# The key check is that checker.check_model() passes,
# which is done by _optimized when check=True

# Additionally verify all inputs have valid elem_type
for input_val in optimized_model.graph.input:
self.assertIsNotNone(input_val.type)
self.assertTrue(input_val.type.HasField('tensor_type'))
# elem_type must not be UNDEFINED (TensorProto.UNDEFINED = 0)
self.assertNotEqual(input_val.type.tensor_type.elem_type,
TensorProto.UNDEFINED,
f"Input {input_val.name} has UNDEFINED elem_type")


if __name__ == "__main__":
unittest.main()
48 changes: 48 additions & 0 deletions tests/test_simple.cc
@@ -5,6 +5,7 @@
#include <gtest/gtest.h>
#include <onnxoptimizer/optimize.h>
#include <onnx/defs/parser.h>
#include <onnx/checker.h>

TEST(OptimizerTest, NopReshape) {
const char* graph_str = R"(
@@ -27,3 +28,50 @@ TEST(OptimizerTest, NopReshape) {
ASSERT_EQ(optimized_model.graph().node().size(), 1);
ASSERT_EQ(optimized_model.graph().node()[0].op_type(), "Identity");
}

TEST(OptimizerTest, SplitPredictPreservesElemType) {
// Test that split_predict preserves elem_type for intermediate values
// This reproduces the bug where intermediate values without value_info
// cause split_predict to generate invalid models with missing elem_type
const char* graph_str = R"(
<
ir_version: 7,
opset_import: [ "": 13]
>
agraph (float[2] X) => (float[2] Y)
{
# Pure operation that can go to init net
one = Constant<value: tensor = float[2] {1.0, 1.0}>()
added = Add(X, one)

# Impure operation that must stay in predict net
# This uses 'added', making it a boundary value
random = RandomUniform<dtype: int = 1, shape: ints = [2]>()
Y = Add(random, added)
}
)";

onnx::ModelProto model;
const onnx::Status status = onnx::OnnxParser::Parse(model, graph_str);
ASSERT_TRUE(status.IsOK());  // nothing below is meaningful if parsing failed

// Run split_predict optimization
auto optimized_model = onnx::optimization::Optimize(model, {"split_predict"});

// Verify the model is valid - this will catch missing elem_type
try {
onnx::checker::check_model(optimized_model);
} catch (const std::exception& e) {
FAIL() << "Optimized model failed validation: " << e.what();
}

// Verify all inputs have valid elem_type (not UNDEFINED)
for (const auto& input : optimized_model.graph().input()) {
ASSERT_TRUE(input.has_type()) << "Input " << input.name() << " missing type";
ASSERT_TRUE(input.type().has_tensor_type()) << "Input " << input.name() << " missing tensor_type";
// Check elem_type is not UNDEFINED
ASSERT_NE(input.type().tensor_type().elem_type(),
onnx::TensorProto_DataType_UNDEFINED)
<< "Input " << input.name() << " has UNDEFINED elem_type";
}
}