diff --git a/onnxoptimizer/onnxoptimizer_main.py b/onnxoptimizer/onnxoptimizer_main.py index 594ca69be..1236ae204 100644 --- a/onnxoptimizer/onnxoptimizer_main.py +++ b/onnxoptimizer/onnxoptimizer_main.py @@ -59,6 +59,12 @@ def main(): parser.add_argument( "--fixed_point", action="store_true", default=False, help="fixed point" ) + parser.add_argument( + "--skip_infer_shapes", + action="store_true", + default=False, + help="Skip shape inference after optimization", + ) argv = sys.argv.copy() args = parser.parse_args(format_argv(sys.argv)) @@ -98,6 +104,8 @@ def main(): if model is None: print("onnxoptimizer failed") sys.exit(1) + if not args.skip_infer_shapes: + model = onnx.shape_inference.infer_shapes(model) try: onnx.save(proto=model, f=output_file) except Exception: