Skip to content

Conversation

@P-Mahi10
Copy link

Fix Issue #2736

This change extends constant folding for Reshape to eliminate dynamic shape
subgraphs of the form Shape → Gather → Concat → Reshape when the target shape
contains exactly one dynamic dimension. The optimizer rewrites the reshape to
use a static -1 dimension, relying on ONNX Reshape semantics to infer the
dynamic value from the input element count. This removes runtime shape
evaluation and enables further canonicalization.

Key Changes

  • onnxscript/optimizer/_constant_folding.py
    Updated the reshape optimizer to detect reshape targets with exactly one
    dynamic dimension and replace the dynamic shape computation with a static
    shape containing -1.

  • onnxscript/optimizer/_constant_folding_test.py
    Added a unit test to verify that dynamic shape subgraphs (Shape, Gather,
    Concat) are eliminated and that the reshape is either rewritten with -1
    or folded away entirely.

@P-Mahi10
Copy link
Author

@microsoft-github-policy-service agree

for i, dim in enumerate(shape_value.dims):
if isinstance(dim, ir.SymbolicDim) and dim.value is None:
dynamic_indices.append(i)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

for i, dim in enumerate(shape_value.dims):
if isinstance(dim, ir.SymbolicDim) and dim.value is None:
dynamic_indices.append(i)

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
if not isinstance(dim, int):
all_others_static = False
break

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

if not isinstance(dim, int):
all_others_static = False
break

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
new_shape_list.append(-1)
else:
new_shape_list.append(dim)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

len(nodes), 0,
f"Expected no {op_type} nodes after optimization, found {len(nodes)}"
)

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
len(reshape_nodes) == 1 or len(identity_nodes) == 1,
f"Expected one Reshape or Identity node, found {len(reshape_nodes)} Reshape and {len(identity_nodes)} Identity"
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

len(reshape_nodes) == 1 or len(identity_nodes) == 1,
f"Expected one Reshape or Identity node, found {len(reshape_nodes)} Reshape and {len(identity_nodes)} Identity"
)

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
shape_input = reshape_node.inputs[1]
self.assertIsNotNone(shape_input, "Reshape shape input should not be None")
self.assertIsNotNone(shape_input.const_value, "Shape input should be a constant")

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

shape_input = reshape_node.inputs[1]
self.assertIsNotNone(shape_input, "Reshape shape input should not be None")
self.assertIsNotNone(shape_input.const_value, "Shape input should be a constant")

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
@codecov
Copy link

codecov bot commented Jan 15, 2026

Codecov Report

❌ Patch coverage is 31.57895% with 26 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.15%. Comparing base (3adee71) to head (f387138).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
onnxscript/optimizer/_constant_folding.py 14.28% 16 Missing and 2 partials ⚠️
onnxscript/optimizer/_constant_folding_test.py 52.94% 7 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2769      +/-   ##
==========================================
- Coverage   70.21%   70.15%   -0.06%     
==========================================
  Files         228      228              
  Lines       27316    27354      +38     
  Branches     2769     2780      +11     
==========================================
+ Hits        19179    19191      +12     
- Misses       7188     7211      +23     
- Partials      949      952       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Development

Successfully merging this pull request may close these issues.

2 participants