Pass "fuse_bn_into_conv" causes output mismatch
Issue
Applying only the fuse_bn_into_conv pass with onnxoptimizer 0.3.19 changes numerical outputs. The original model is numerically stable with the provided oracle inputs, but the optimized graph deviates immediately after this single pass.
Environment
- Ubuntu 20.04
- Python 3.10
- onnx==1.19.0
- onnxruntime==1.23.2
- onnxoptimizer==0.3.19 (latest)
Repro steps (run from this folder)
- Download and unzip the attached archive (fuse_bn_into_conv_repro.tar.gz), then cd into the extracted directory:
tar -xzvf fuse_bn_into_conv_repro.tar.gz
cd fuse_bn_into_conv_repro
- Create a Python environment (Python 3.10) and install dependencies:
python3 -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install -r requirements.txt
- Optimize the case with only fuse_bn_into_conv (writes model.opt.onnx next to model.onnx):
python optimize_model.py --case ./case_00057_seed20702402
- Differential test original vs optimized outputs using the stored oracle inputs:
python diff_test.py --case ./case_00057_seed20702402
Observed results
- case_00057_seed20702402: overall max_abs=3.815e-06, max_rel=2.780e-07; detailed per-output diff below.
Expected
fuse_bn_into_conv should be semantics-preserving. Applying only this pass should not change any output values. Please investigate why the optimized graph diverges and whether the pass is incorrectly folding batch norm parameters into convolutions for this model.
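For reference when investigating, the folding in question can be sketched in NumPy. This is a minimal sketch under the textbook formula for inference-mode batch norm, not the code used by onnxoptimizer; the helper names `fold_bn_into_conv` and `conv1x1`, the random shapes, and the 1x1 kernel are assumptions chosen for illustration. It computes Conv followed by BatchNormalization, then a single Conv with folded parameters, and reports the largest float32 difference between the two.

```python
import numpy as np


def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold inference-mode BatchNorm parameters into conv weight and bias.

    w has shape (C_out, C_in, kH, kW); all other arrays have shape (C_out,).
    Hypothetical helper: textbook formula, not onnxoptimizer's implementation.
    """
    scale = gamma / np.sqrt(var + eps)          # per-output-channel factor
    w_folded = w * scale[:, None, None, None]   # scale each output filter
    b_folded = (b - mean) * scale + beta        # shift absorbed into the bias
    return w_folded, b_folded


def conv1x1(x, w, b):
    """1x1 convolution: x is (N, C_in, H, W), w is (C_out, C_in, 1, 1)."""
    return np.einsum("nchw,oc->nohw", x, w[:, :, 0, 0]) + b[None, :, None, None]


rng = np.random.default_rng(0)
f32 = np.float32
x = rng.standard_normal((2, 3, 4, 3)).astype(f32)
w = rng.standard_normal((4, 3, 1, 1)).astype(f32)
b = rng.standard_normal(4).astype(f32)
gamma = rng.standard_normal(4).astype(f32)
beta = rng.standard_normal(4).astype(f32)
mean = rng.standard_normal(4).astype(f32)
var = rng.uniform(0.5, 2.0, 4).astype(f32)
eps = f32(1e-5)


def per_channel(v):
    # Broadcast a (C_out,) vector over an (N, C_out, H, W) tensor.
    return v[None, :, None, None]


# Reference path: Conv, then BatchNormalization in inference mode.
y = conv1x1(x, w, b)
ref = per_channel(gamma) * (y - per_channel(mean)) / np.sqrt(per_channel(var) + eps) + per_channel(beta)

# Fused path: one Conv with folded weight and bias.
w_f, b_f = fold_bn_into_conv(w, b, gamma, beta, mean, var, eps)
fused = conv1x1(x, w_f, b_f)

max_abs = float(np.max(np.abs(ref - fused)))
print(f"max_abs={max_abs:.3e}")
```

The two paths are algebraically identical but perform their float32 operations in a different order, so the printed difference gives a rough baseline for how much deviation rounding alone can produce with this formula.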
Differential Test Output Details
- case_00057_seed20702402
Case: case_00057_seed20702402
output[0]: max_abs=9.537e-07, max_rel=9.442e-08, shape=(2, 1, 1, 1)
output[1]: max_abs=3.815e-06, max_rel=2.780e-07, shape=(2, 1, 4, 3)
output[2]: max_abs=7.153e-07, max_rel=2.719e-07, shape=(2, 1, 1, 1)
output[3]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(2, 2, 2, 1, 1)
output[4]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(2, 2, 2, 16)
output[5]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(2, 2, 2, 2, 1)
output[6]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(2, 2, 16)
output[7]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(2, 2, 1)
output[8]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(2, 16)
output[9]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(1, 16)
output[10]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(1,)
output[11]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(1, 16)
output[12]: max_abs=0.000e+00, max_rel=0.000e+00, shape=()
output[13]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(1, 16)
output[14]: max_abs=0.000e+00, max_rel=0.000e+00, shape=()
output[15]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(1,)
output[16]: max_abs=0.000e+00, max_rel=0.000e+00, shape=(16, 44)
Overall: max_abs=3.815e-06, max_rel=2.780e-07
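diff_test.py is not reproduced in this report. As a rough sketch of how per-output max_abs and max_rel figures like the ones above could be computed, assuming elementwise absolute error and error relative to the reference magnitude (the function name `diff_metrics` and the `tiny` guard against division by zero are hypothetical):

```python
import numpy as np


def diff_metrics(ref, test, tiny=1e-12):
    """Return (max_abs, max_rel) between a reference and a test array.

    Hypothetical helper; the actual diff_test.py may define max_rel differently.
    """
    ref = np.asarray(ref, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    if ref.shape != test.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {test.shape}")
    if ref.size == 0:
        return 0.0, 0.0
    abs_err = np.abs(ref - test)
    # Guard the denominator so exact zeros in the reference do not divide by zero.
    rel_err = abs_err / np.maximum(np.abs(ref), tiny)
    return float(abs_err.max()), float(rel_err.max())


# Toy usage with made-up values (not taken from the model in this report).
ref = np.array([1.0, 2.0, 4.0])
test = np.array([1.0, 2.5, 4.0])
max_abs, max_rel = diff_metrics(ref, test)
print(f"max_abs={max_abs:.3e}, max_rel={max_rel:.3e}")
# prints: max_abs=5.000e-01, max_rel=2.500e-01
```

Upcasting to float64 before subtracting keeps the metric itself from adding float32 rounding on top of the difference being measured.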
Attachments
- README.md (this document)
- requirements.txt (dependency versions)
- optimize_model.py (runs only fuse_bn_into_conv and saves model.opt.onnx)
- diff_test.py (runs original vs optimized with oracle inputs and reports max_abs/max_rel)
- case_00057_seed20702402/ (contains model.onnx and oracle.pkl used for both runs)