From d4b7d60af8a864de33ff7a63a9fdfe2d28e4200f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 Jan 2026 15:40:28 -0800 Subject: [PATCH 01/10] Replace op_schema with op_signature Signed-off-by: Justin Chu --- onnxscript/_internal/autocast.py | 27 +++++++++++++------------- onnxscript/_internal/converter.py | 6 +++--- onnxscript/_internal/evaluator.py | 32 ++++++++++++------------------- 3 files changed, 29 insertions(+), 36 deletions(-) diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index bc3e16f79e..59732d0c08 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -8,9 +8,9 @@ import numpy as np import onnx import onnx.helper # noqa: TID251 -from onnx.defs import OpSchema from onnxscript import ir, tensor +from onnxscript.ir import _schemas if TYPE_CHECKING: from onnxscript._internal import converter @@ -126,7 +126,7 @@ def cast_pyvalue_to_os_tensor(pyvalue, dtype=None): def cast_inputs( get_type_info: Callable[[Any], Any], cast: Callable[[Any, Any], Any], - op_schema: OpSchema | None, + op_signature: _schemas.OpSignature | None, args, ) -> tuple[Any, ...]: """Uses schema specification to support a limited form of auto-casting. @@ -140,12 +140,15 @@ def cast_inputs( This is used by the converter in a static-mode, as well as by the eager-mode execution in a dynamic-mode. """ - if op_schema is None: + if op_signature is None: # Either an error or a custom op. # No checks/casts in this case. return tuple(cast(x, None) for x in args) - expected_inputs = op_schema.inputs + # Filter to get only input parameters (not AttributeParameters) + expected_inputs = [ + param for param in op_signature.params if isinstance(param, _schemas.Parameter) + ] # We make two passes. In the first pass, we identify known type-bindings for # type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}. # In the second pass, we use these bindings to cast scalar-values to @@ -156,17 +159,15 @@ def cast_inputs( for i, x in enumerate(args): if i < len(expected_inputs): expected = expected_inputs[i] - elif expected_inputs[-1].option == OpSchema.FormalParameterOption.Variadic: + elif expected_inputs[-1].variadic: expected = expected_inputs[-1] - if not expected.is_homogeneous: - args_typevars.append((x, None)) - continue + # TODO(justinchuby): Handle is_homogeneous params else: raise ValueError( f"Number of actual parameters {len(args)} " f"exceeds number of formal parameters {len(expected_inputs)}." ) - typevar = expected.type_str + typevar = expected.type_constraint.name if "(" not in typevar: # typevar is an identifier, like "T" typeinfo = get_type_info(x) @@ -177,18 +178,18 @@ def cast_inputs( return tuple(cast_args) -def dynamic_cast_inputs(op_schema: OpSchema, args): +def dynamic_cast_inputs(op_signature: _schemas.OpSignature, args): """Used for autocast during eager-mode execution.""" def get_type_info(x): return x.dtype if isinstance(x, tensor.Tensor) else None - return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_schema, args) + return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_signature, args) def static_cast_inputs( converter_: converter.Converter, - op_schema: Optional[OpSchema], + op_signature: Optional[_schemas.OpSignature], args: Sequence[Optional[ir.Value]], ) -> tuple[str, ...]: """Used for autocast during script-translation. 
@@ -212,4 +213,4 @@ def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]: return converter_.emit1([x_cast], "CastLike", [x, y]) return x - return cast_inputs(get_type_info, cast_like, op_schema, args) + return cast_inputs(get_type_info, cast_like, op_signature, args) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index dd902ac7ab..6ab228ef4d 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -887,7 +887,7 @@ def _translate_call_expr( else: args = [self._translate_opt_expr(x) for x in node.args] attrs = [self._translate_attr(x.arg, x.value) for x in node.keywords] - args = autocast.static_cast_inputs(self, callee.op_schema, args) + args = autocast.static_cast_inputs(self, callee.op_signature, args) # In ONNX, there is no way to explicitly specify a None value for an attribute. # Instead, the attribute must be omitted from the attribute list. @@ -896,8 +896,8 @@ def _translate_call_expr( return callee, args, attrs def _cast_like_binary_expression(self, op, left, right) -> tuple[ir.Value, ir.Value]: - schema = op.op_schema - return autocast.static_cast_inputs(self, schema, (left, right)) + op_signature = op.op_signature + return autocast.static_cast_inputs(self, op_signature, (left, right)) def _translate_binary_op_expr(self, node: ast.BinOp): op = type(node.op) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index 1415733397..5eefe3d963 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -189,38 +189,35 @@ def eval( attributes: The ONNX attributes to the op. """ attributes = _unwrap_tensors_in_kwargs(attributes) - attributes, closure = self.adapt_attributes(schema, attributes) - inputs = self.adapt_inputs(schema, inputs) + attributes, closure = self._adapt_attributes(attributes) + inputs = self._adapt_inputs(schema, inputs) outputs = self._eval(schema, inputs, attributes, closure) - return self.adapt_outputs(schema, outputs) + return self._adapt_outputs(outputs) - def adapt_inputs(self, schema: onnx.defs.OpSchema, inputs: Sequence[ExtendedModeValue]): + def _adapt_inputs(self, schema: onnx.defs.OpSchema, inputs: Sequence[ExtendedModeValue]): """Transform inputs to the expected format for the evaluator. Enables some syntactic sugar, such as the use of Python scalars, in a manner consistent with the translator. See autocast.py for details. """ - return autocast.dynamic_cast_inputs(schema, inputs) + op_signature = _schemas.OpSignature.from_op_schema(schema) + return autocast.dynamic_cast_inputs(op_signature, inputs) - def adapt_attributes( - self, schema: onnx.defs.OpSchema, attributes: Mapping[str, ExtendedModeValue] + def _adapt_attributes( + self, attributes: Mapping[str, ExtendedModeValue] ) -> tuple[dict[str, ExtendedModeValue], dict[str, ExtendedModeValue]]: """Transform attributes to the expected format for the evaluator. Returns: A closure that can be used to evaluate graph-valued attributes. 
""" - use_graph_attribute = self.use_graph_attribute(schema) closure: dict[Any, Any] = {} adapted_attributes = {} for k, v in attributes.items(): if isinstance(v, values.OnnxClosure): - if use_graph_attribute: - adapted_attributes[k] = v.function_ir.to_graph_proto() - for pyvar, onnxvar in v.function_ir.outer_scope_variables: - closure[onnxvar.value.name] = v.frame.f_locals[pyvar] - else: - adapted_attributes[k] = v.function + adapted_attributes[k] = v.function_ir.to_graph_proto() + for pyvar, onnxvar in v.function_ir.outer_scope_variables: + closure[onnxvar.value.name] = v.frame.f_locals[pyvar] elif callable(v): raise TypeError( f"Error: function-valued attribute {v.__name__} has no graph_proto" @@ -230,18 +227,13 @@ def adapt_attributes( adapted_attributes[k] = v return adapted_attributes, closure - def adapt_outputs(self, schema: onnx.defs.OpSchema, outputs: Sequence[EagerModeValue]): + def _adapt_outputs(self, outputs: Sequence[EagerModeValue]): """Adapt evaluator's output to convention used in onnxscript. Onnxscript uses a tuple/sequence only when number of outputs > 1. """ - del schema # unused return outputs[0] if len(outputs) == 1 else outputs - def use_graph_attribute(self, schema: onnx.defs.OpSchema): - del schema # unused - return True - @abc.abstractmethod def _eval( self, From 78a25d816aeb835ad511e5f9913902940e733140 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 Jan 2026 20:47:11 -0800 Subject: [PATCH 02/10] wip Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index 5eefe3d963..edba055236 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -188,29 +188,32 @@ def eval( inputs: The ONNX inputs to the op. attributes: The ONNX attributes to the op. """ + op_signature = _schemas.OpSignature.from_op_schema(schema) attributes = _unwrap_tensors_in_kwargs(attributes) - attributes, closure = self._adapt_attributes(attributes) - inputs = self._adapt_inputs(schema, inputs) + attributes, closure = self._adapt_attributes(op_signature, attributes) + inputs = self._adapt_inputs(op_signature, inputs) outputs = self._eval(schema, inputs, attributes, closure) return self._adapt_outputs(outputs) - def _adapt_inputs(self, schema: onnx.defs.OpSchema, inputs: Sequence[ExtendedModeValue]): + def _adapt_inputs( + self, op_signature: _schemas.OpSignature, inputs: Sequence[ExtendedModeValue] + ): """Transform inputs to the expected format for the evaluator. Enables some syntactic sugar, such as the use of Python scalars, in a manner consistent with the translator. See autocast.py for details. """ - op_signature = _schemas.OpSignature.from_op_schema(schema) return autocast.dynamic_cast_inputs(op_signature, inputs) def _adapt_attributes( - self, attributes: Mapping[str, ExtendedModeValue] + self, op_signature, attributes: Mapping[str, ExtendedModeValue] ) -> tuple[dict[str, ExtendedModeValue], dict[str, ExtendedModeValue]]: """Transform attributes to the expected format for the evaluator. Returns: A closure that can be used to evaluate graph-valued attributes. 
""" + use_graph_attribute = self.use_graph_attribute(op_singature) closure: dict[Any, Any] = {} adapted_attributes = {} for k, v in attributes.items(): @@ -234,6 +237,12 @@ def _adapt_outputs(self, outputs: Sequence[EagerModeValue]): """ return outputs[0] if len(outputs) == 1 else outputs + + def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool: + del schema # unused + return True + + @abc.abstractmethod def _eval( self, From f4e47bbf2120abe5b3c69c7894318a929a831345 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:03:00 -0800 Subject: [PATCH 03/10] Clean up _to_model_proto Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 49 ++++---- onnxscript/_internal/irbuilder.py | 8 +- onnxscript/_internal/values.py | 185 +++--------------------------- onnxscript/tensor.py | 2 + 4 files changed, 53 insertions(+), 191 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index edba055236..9b5cb2eabe 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -138,6 +138,21 @@ def eval( inputs: The ONNX inputs to the op. attributes: The ONNX attributes to the op. """ + # Deprecated. Implement eval_op instead + + def eval_op( + self, + op: values.Op, + args: Sequence[ExtendedModeValue], + kwargs: Mapping[str, ExtendedModeValue], + ): + """Evaluates an Op. + + Args: + op: The Op to evaluate. + args: The positional arguments to the op. + kwargs: The keyword arguments to the op. + """ def eval_function( self, @@ -175,26 +190,6 @@ def __init__(self, ignore_unknown_function_kwargs: bool = False): """ self._ignore_unknown_function_kwargs = ignore_unknown_function_kwargs - def eval( - self, - schema: onnx.defs.OpSchema, - inputs: Sequence[ExtendedModeValue], - attributes: Mapping[str, Any], - ): - """Evaluates an ONNX op. - - Args: - schema: The OpSchema of the operator to evaluate. - inputs: The ONNX inputs to the op. - attributes: The ONNX attributes to the op. - """ - op_signature = _schemas.OpSignature.from_op_schema(schema) - attributes = _unwrap_tensors_in_kwargs(attributes) - attributes, closure = self._adapt_attributes(op_signature, attributes) - inputs = self._adapt_inputs(op_signature, inputs) - outputs = self._eval(schema, inputs, attributes, closure) - return self._adapt_outputs(outputs) - def _adapt_inputs( self, op_signature: _schemas.OpSignature, inputs: Sequence[ExtendedModeValue] ): @@ -260,6 +255,20 @@ def _eval( closure: The closure to use when evaluating graph-valued attributes. """ + def eval_op( + self, + op: values.Op, + args: Sequence[ExtendedModeValue], + kwargs: Mapping[str, ExtendedModeValue], + ): + op_signature = op.op_signature + assert op_signature is not None, f"Op {op.name} has no signature." 
+ attributes = _unwrap_tensors_in_kwargs(kwargs) + attributes, closure = self._adapt_attributes(op_signature, attributes) + inputs = self._adapt_inputs(op_signature, args) + outputs = self._eval(schema, inputs, attributes, closure) + return self._adapt_outputs(outputs) + def eval_function( self, function: values.OnnxFunction, diff --git a/onnxscript/_internal/irbuilder.py b/onnxscript/_internal/irbuilder.py index e5fa80622e..f287b6b1ab 100644 --- a/onnxscript/_internal/irbuilder.py +++ b/onnxscript/_internal/irbuilder.py @@ -77,7 +77,7 @@ def append_parameter(self, parameter: ir.Value | ir.Attr) -> None: def add_nested_function(self, fun: IRFunction) -> None: self.nested_functions[fun.name] = fun - def get_called_functions(self) -> dict[str, onnx.FunctionProto]: + def get_called_functions(self) -> dict[str, ir.Function]: called_functions: dict[str, values.OnnxFunction] = {} def visit(function_ir: IRFunction): @@ -94,12 +94,12 @@ def add(f: values.OnnxFunction): visit(self) - return {name: f.to_function_proto() for name, f in called_functions.items()} + return {name: f.function_ir for name, f in called_functions.items()} def to_graph_proto(self) -> onnx.GraphProto: """Converts this instance into a `onnx.GraphProto`.""" - return ir.to_proto(self.graph) + return ir.serde.serialize_graph(self.graph) def to_function_proto(self) -> onnx.FunctionProto: """Converts this instance into a `onnx.FunctionProto`.""" - return ir.to_proto(self) + return ir.serde.serialize_function(self) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 2f22e1eefa..d87e661f46 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -4,6 +4,7 @@ # ruff: noqa: TID251 from __future__ import annotations +from collections.abc import Collection import dataclasses import functools @@ -168,9 +169,6 @@ def name(self) -> str: ... @property def opset(self) -> Opset: ... - @property - def op_schema(self) -> Optional[onnx.defs.OpSchema]: ... - @property def op_signature(self) -> Optional[_schemas.OpSignature]: ... @@ -203,15 +201,16 @@ def __init__( ) def __call__(self, *args, **kwargs): - # FIXME(after #225): Move import to the top of the file. from onnxscript._internal import evaluator # pylint: disable=import-outside-toplevel - schema = self.op_schema - if schema is None: - raise RuntimeError( - f"Op '{self.name}' does not have an OpSchema and cannot be evaluated." - ) - return evaluator.default().eval(schema, args, kwargs) + default_evaluator = evaluator.default() + if hasattr(default_evaluator, "eval"): + # Interface prior to onnxscript 0.6, used by PyTorch 2.10 and older + if self.op_schema is None: + raise ValueError(f"OpSchema not found for op '{self.name}'.") + return default_evaluator.eval(self.op_schema, args, kwargs) + # Use the new interface + return evaluator.default().eval_op(self, args, kwargs) @property def name(self) -> str: @@ -225,10 +224,6 @@ def opset(self) -> Opset: def op_schema(self) -> Optional[onnx.defs.OpSchema]: return self._op_schema - def has_schema(self) -> bool: - """Returns True if this op has an OpSchema.""" - return self.op_schema is not None - @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" @@ -261,99 +256,6 @@ class OnnxClosure: function: Any -@dataclasses.dataclass -class TypeConstraint: - """Represents a type constraint for an ONNX op. - - Attributes: - name: The name of the type constraint. - allowed_types: The allowed types for the type constraint. 
- """ - - name: str - allowed_types: list[str] - description: str = "" - - def as_tuple(self) -> tuple[str, list[str], str]: - """Returns the type constraint as a tuple.""" - return (self.name, self.allowed_types, self.description) - - -def _op_schema_from_function_ir( - function_ir: irbuilder.IRFunction, opset: Opset -) -> onnx.defs.OpSchema: - """Construct an ONNX OpSchema from an IRFunction.""" - - # Find all distinct types in the inputs and outputs - distinct_types = {_typeinfo(arg) for arg in function_ir.inputs}.union( - {_typeinfo(arg) for arg in function_ir.outputs} - ) - # Create a mapping from type to a unique name - type_to_constraint = {} - for i, type_ in enumerate(distinct_types): - name = f"T{i}" - type_to_constraint[type_] = TypeConstraint( - name=type_annotation.get_type_constraint_name(type_) or name, - allowed_types=type_annotation.pytype_to_type_strings(type_), - ) - - formal_inputs = [ - onnx.defs.OpSchema.FormalParameter( - arg.name, - type_to_constraint[_typeinfo(arg)].name, - param_option=( - onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(_typeinfo(arg)) - else onnx.defs.OpSchema.FormalParameterOption.Single - ), - # TODO(justinchu): Check this is_homogeneous thing - is_homogeneous=True, - ) - for arg in function_ir.inputs - ] - formal_outputs = [ - onnx.defs.OpSchema.FormalParameter( - arg.name, - type_to_constraint[_typeinfo(arg)].name, - param_option=( - onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(_typeinfo(arg)) - else onnx.defs.OpSchema.FormalParameterOption.Single - ), - # TODO(justinchu): Check this is_homogeneous thing - is_homogeneous=True, - ) - for arg in function_ir.outputs - ] - return onnx.defs.OpSchema( - function_ir.name, - opset.domain, - since_version=opset.version, - doc=function_ir.doc_string or "", - inputs=formal_inputs, - outputs=formal_outputs, - type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()], - attributes=[ - *[ - onnx.defs.OpSchema.Attribute( - attr.name, - type=onnx.defs.OpSchema.AttrType(attr.type), # type: ignore[call-arg] - ) - for attr in function_ir.attrs - if attr.value is None - ], - *[ - onnx.defs.OpSchema.Attribute( - attr.name, - default_value=ir.to_proto(attr), - ) - for attr in function_ir.attrs - if attr.value is not None - ], - ], - ) - - class OnnxFunction(Op, Generic[_P, _R]): """Represents an ONNX op for which a function-body has been defined in onnxscript. @@ -399,27 +301,6 @@ def __init__( # Experimental fields self.traceable = False - @property - @deprecation.deprecated( - since="0.1", - removed_in="the future", - instructions="use '.name' instead", - ) - def opname(self) -> str: - # NOTE: This is a temporary alias for backward compatibility with PyTorch 2.0. - # TODO: Remove this in onnxscript 0.3. - return self.name - - @property - def op_schema(self) -> Optional[onnx.defs.OpSchema]: - """Construct an OpSchema from function_ir.""" - if self._op_schema is not None: - return self._op_schema - - self._op_schema = _op_schema_from_function_ir(self.function_ir, self.opset) - - return self._op_schema - @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" @@ -438,28 +319,8 @@ def op_signature(self) -> Optional[_schemas.OpSignature]: def op_signature(self, value: _schemas.OpSignature): self._signature = value - def __getitem__(self, instance): - """Returns a lambda to evaluate function using given evaluator instance. 
- - Usage: - script_fun(X) executes the function using the default evaluator instance. - script_fun[instance](X) executes the function using the given evaluator instance. - """ - - def fun(*args, **kwargs): - # FIXME(after #225): Move import to the top of the file. - from onnxscript._internal import ( # pylint: disable=import-outside-toplevel - evaluator, - ) - - with evaluator.default_as(instance): - return self.__call__(*args, **kwargs) - - return fun - def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: """Implements an eager-mode execution of an onnxscript function.""" - # FIXME(after #225): Move import to the top of the file. from onnxscript._internal import evaluator # pylint: disable=import-outside-toplevel return evaluator.default().eval_function(self, args, kwargs) # type: ignore[arg-type, return-value] @@ -490,7 +351,7 @@ def to_model_proto(self, **kwargs): def _to_model_proto( self, - functions=None, + functions: Collection[ir.Function] | None = None, io_types: Optional[ONNXType] = None, input_types: Optional[Sequence[ONNXType]] = None, output_types: Optional[Sequence[ONNXType]] = None, @@ -522,27 +383,15 @@ def _to_model_proto( if functions is None: sub_functions = self.function_ir.get_called_functions() functions = sub_functions.values() - else: - - def to_proto(f): - if isinstance(f, onnx.FunctionProto): - return f - if isinstance(f, OnnxFunction): - return f.to_function_proto() - raise TypeError("Expected a value of type FunctionProto of OnnxFunction") - - functions = [to_proto(f) for f in functions] # Determine opset imports opsets = self.function_ir.graph.opset_imports - for proto in functions: - if proto.domain not in opsets: - opsets[proto.domain] = 1 - # TODO(rama): Handle conflicts with appropriate error/warning message. - for opset in proto.opset_import: - if opset.domain not in opsets: - opsets[opset.domain] = opset.version + for func in functions: + if func.domain not in opsets: + opsets[func.domain] = 1 + + # No need to collect opsets from functions if "" not in opsets: # No operator is using the standard opset. @@ -559,8 +408,10 @@ def to_proto(f): # Create the model model = ir.Model(self.function_ir.graph, ir_version=ir_version) + for func in functions: + model.functions[func.identifier()] = func + model_proto = ir.to_proto(model) - model_proto.functions.extend(functions) # Set additional type information if provided graph = model_proto.graph diff --git a/onnxscript/tensor.py b/onnxscript/tensor.py index f1d781b808..6ad8f6bf12 100644 --- a/onnxscript/tensor.py +++ b/onnxscript/tensor.py @@ -16,6 +16,8 @@ class Tensor: Serves to define overloaded ops with an ONNX/ONNXScript semantics. 
""" + # TODO(justinchuby): Remove the tensor class and use ir.Value instead + def __init__(self, nparray: Optional[np.ndarray], opset=None): if nparray is not None and not isinstance(nparray, np.ndarray): raise TypeError( From be172c16d58a05bb76d09246b681e3f270e2ee54 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:22:59 -0800 Subject: [PATCH 04/10] wip Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 33 ++++++++++++++++++------------- onnxscript/_internal/values.py | 8 ++++---- onnxscript/ir/_schemas.py | 5 ++++- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index 9b5cb2eabe..fc12ce725c 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -208,14 +208,17 @@ def _adapt_attributes( Returns: A closure that can be used to evaluate graph-valued attributes. """ - use_graph_attribute = self.use_graph_attribute(op_singature) + use_graph_attribute = self.use_graph_attribute(op_signature) closure: dict[Any, Any] = {} adapted_attributes = {} for k, v in attributes.items(): if isinstance(v, values.OnnxClosure): - adapted_attributes[k] = v.function_ir.to_graph_proto() - for pyvar, onnxvar in v.function_ir.outer_scope_variables: - closure[onnxvar.value.name] = v.frame.f_locals[pyvar] + if use_graph_attribute: + adapted_attributes[k] = v.function_ir.to_graph_proto() + for pyvar, onnxvar in v.function_ir.outer_scope_variables: + closure[onnxvar.value.name] = v.frame.f_locals[pyvar] + else: + adapted_attributes[k] = v.function elif callable(v): raise TypeError( f"Error: function-valued attribute {v.__name__} has no graph_proto" @@ -234,7 +237,7 @@ def _adapt_outputs(self, outputs: Sequence[EagerModeValue]): def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool: - del schema # unused + del op_signature # unused return True @@ -266,7 +269,7 @@ def eval_op( attributes = _unwrap_tensors_in_kwargs(kwargs) attributes, closure = self._adapt_attributes(op_signature, attributes) inputs = self._adapt_inputs(op_signature, args) - outputs = self._eval(schema, inputs, attributes, closure) + outputs = self._eval(op.op_schema, inputs, attributes, closure) return self._adapt_outputs(outputs) def eval_function( @@ -285,6 +288,8 @@ def eval_function( kwargs: The keyword arguments to the function. """ op_signature = function.op_signature + if op_signature is None: + raise RuntimeError(f"Function {function.name} has no signature.") # Split happens in the evaluator instead of the OnnxFunction __call__ method # so that evaluators can control behaviors like whether to fill in default values for attributes. 
tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_signature( @@ -514,7 +519,7 @@ def _call_ort( return [_numpy_to_onnxscript_value(x) for x in result] -def _schema_id(schema: onnx.defs.OpSchema) -> tuple[str, str, int]: +def _op_identifier(schema) -> tuple[str, str, int]: return schema.name, schema.domain, schema.since_version @@ -562,13 +567,13 @@ def __init__(self) -> None: super().__init__() self._python_ops: dict[tuple[str, str, int], Any] = {} - def use_graph_attribute(self, schema: onnx.defs.OpSchema) -> bool: - return _schema_id(schema) not in self._python_ops + def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool: + return _op_identifier(op_signature) not in self._python_ops def _eval(self, schema, inputs, attributes, closure): - schemaid = _schema_id(schema) - if schemaid in self._python_ops: - return self._python_ops[schemaid](inputs, attributes) + identifier = _op_identifier(schema) + if identifier in self._python_ops: + return self._python_ops[identifier](inputs, attributes) else: return super()._eval(schema, inputs, attributes, closure) @@ -576,8 +581,8 @@ def register(self, opset: values.Opset) -> Callable[[_T], _T]: assert opset is not None def decorator(function: _T) -> _T: - schema = opset[function.__name__] - self._python_ops[_schema_id(schema)] = function + op_signature = opset[function.__name__].op_signature + self._python_ops[_op_identifier(op_signature)] = function return function return decorator diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index d87e661f46..70f35d1c4f 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -4,7 +4,6 @@ # ruff: noqa: TID251 from __future__ import annotations -from collections.abc import Collection import dataclasses import functools @@ -12,7 +11,8 @@ import logging import types import typing -from typing import ( # type: ignore[attr-defined] +from collections.abc import Collection +from typing import ( Any, Callable, ClassVar, @@ -28,7 +28,7 @@ import onnx_ir as ir from typing_extensions import ParamSpec -from onnxscript._internal import ast_utils, deprecation, irbuilder, sourceinfo, type_annotation +from onnxscript._internal import ast_utils, irbuilder, sourceinfo from onnxscript._internal import converter as converter_module from onnxscript.ir import _schemas from onnxscript.onnx_types import ONNXType @@ -123,7 +123,7 @@ def __contains__(self, opname): def __str__(self) -> str: return self.domain - def __getattr__(self, attr: str): + def __getattr__(self, attr: str) -> Op: try: schema = onnx.defs.get_schema(attr, self.version, self.domain) return Op(self, attr, schema) diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index d4d88ab5bb..66f6875eb2 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -339,6 +339,7 @@ class OpSignature: params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field( init=False, repr=False ) + since_version: int = 1 def __post_init__(self): self.params_map = {param.name: param for param in self.params} @@ -415,11 +416,12 @@ def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: overload="", params=params, outputs=outputs, + since_version=op_schema.since_version, ) @classmethod def from_function( - cls, func, domain: str, name: str | None = None, overload: str = "" + cls, func, domain: str, name: str | None = None, overload: str = "", since_version: int = 1 ) -> OpSignature: """Produce an OpSignature from a function using type 
annotation.""" @@ -545,4 +547,5 @@ def from_function( overload=overload, params=params, outputs=outputs, + since_version=since_version, ) From 38ecac75bd1929825c349cfa065bbc5d1582ff1c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:28:00 -0800 Subject: [PATCH 05/10] update Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 2 -- onnxscript/_internal/values.py | 23 +++++++++-------------- onnxscript/ir/_schemas.py | 7 ++++++- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index fc12ce725c..d7e9d5b775 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -235,12 +235,10 @@ def _adapt_outputs(self, outputs: Sequence[EagerModeValue]): """ return outputs[0] if len(outputs) == 1 else outputs - def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool: del op_signature # unused return True - @abc.abstractmethod def _eval( self, diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 70f35d1c4f..015fa4875f 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -108,7 +108,8 @@ def __repr__(self): def __getitem__(self, opname): try: - return onnx.defs.get_schema(opname, self.version, self.domain) + schema = onnx.defs.get_schema(opname, self.version, self.domain) + return Op(self, opname, schema) except Exception: # pylint: disable=broad-except # TODO: more specific exception return None @@ -189,7 +190,13 @@ def __init__( ) -> None: self._opset = opset self._name = name - self._op_schema = op_schema or opset[name] + self._op_schema: onnx.defs.OpSchema | None + if op_schema is not None: + self._op_schema = op_schema + elif (op := opset[name]) is not None: + self._op_schema = op.op_schema + else: + self._op_schema = None self._signature: Optional[_schemas.OpSignature] = None if self._op_schema is None: @@ -484,18 +491,6 @@ def function_ir(self) -> irbuilder.IRFunction: return converter.translate_function_signature(func_ast) - @property - def op_schema(self) -> Optional[onnx.defs.OpSchema]: - """Return the OpSchema.""" - - if self._op_schema is not None: - return self._op_schema - - # FIXME(justinchuby): outputs are empty. Need to fix. 
- self._op_schema = _op_schema_from_function_ir(self.function_ir, self._opset) - - return self._op_schema - @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index 66f6875eb2..ea8affc37d 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -421,7 +421,12 @@ def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: @classmethod def from_function( - cls, func, domain: str, name: str | None = None, overload: str = "", since_version: int = 1 + cls, + func, + domain: str, + name: str | None = None, + overload: str = "", + since_version: int = 1, ) -> OpSignature: """Produce an OpSignature from a function using type annotation.""" From 26151b25916528986b9b7d55955c2069d0cfda45 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:38:52 -0800 Subject: [PATCH 06/10] Fixes Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 19 ++++--------------- onnxscript/_internal/evaluator_test.py | 6 ++++-- onnxscript/_internal/values.py | 26 +++++++------------------- onnxscript/ir/_schemas.py | 2 +- 4 files changed, 16 insertions(+), 37 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index d7e9d5b775..1f29f15d68 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -123,22 +123,11 @@ def _unwrap_tensors_in_kwargs(kwargs: Mapping[str, Any]) -> dict[str, Any]: @runtime_checkable class Evaluator(Protocol): - """Protocol for evaluating ONNX ops.""" + """Protocol for evaluating ONNX ops. - def eval( - self, - schema: onnx.defs.OpSchema, - inputs: Sequence[ExtendedModeValue], - attributes: Mapping[str, Any], - ): - """Evaluates an ONNX op. - - Args: - schema: The OpSchema of the operator to evaluate. - inputs: The ONNX inputs to the op. - attributes: The ONNX attributes to the op. - """ - # Deprecated. Implement eval_op instead + NOTE: The ``eval`` method was deprecated and removed. Implement ``eval_op`` + and ``eval_function`` instead. 
+ """ def eval_op( self, diff --git a/onnxscript/_internal/evaluator_test.py b/onnxscript/_internal/evaluator_test.py index c696ddf9b4..4949c04675 100644 --- a/onnxscript/_internal/evaluator_test.py +++ b/onnxscript/_internal/evaluator_test.py @@ -31,11 +31,13 @@ def square(y: FLOAT["N"]) -> FLOAT["N"]: # noqa: F821 np.testing.assert_equal(output, expected) # Test using ort-mixed-evaluator - output = seq_map[evaluator.ort_mixed_evaluator](x) + with evaluator.default_as(evaluator.ort_mixed_evaluator): + output = seq_map(x) np.testing.assert_equal(output, expected) # Test using ort-evaluator - output = seq_map[evaluator.ort_evaluator](x) + with evaluator.default_as(evaluator.ort_evaluator): + output = seq_map(x) np.testing.assert_equal(output, expected) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 015fa4875f..e964078a05 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -300,7 +300,9 @@ def __init__( self.function_ir = irfun self.source = source self.kwargs = kwargs - self._op_schema: Optional[onnx.defs.OpSchema] = None + self._signature = _schemas.OpSignature.from_function( + self.function, domain=self.function_ir.domain, name=self.name + ) # Allow the object to be inspected as a function functools.update_wrapper(self, pyfun) @@ -311,15 +313,6 @@ def __init__( @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" - if self._signature is not None: - return self._signature - - if self.op_schema is None: - return None - - self._signature = _schemas.OpSignature.from_function( - self.function, domain=self.function_ir.domain, name=self.name - ) return self._signature @op_signature.setter @@ -400,6 +393,7 @@ def _to_model_proto( # No need to collect opsets from functions + # FIXME: Collect used opsets from the function nodes if "" not in opsets: # No operator is using the standard opset. # Use the specified version if provided or the default value. @@ -462,6 +456,9 @@ class TracedOnnxFunction(Op): def __init__(self, opset: Opset, func: Callable): super().__init__(opset, func.__name__) self.func = func + self._signature = _schemas.OpSignature.from_function( + self.func, domain="_traced", name=self.name + ) # Allow the object to be inspected as a function functools.update_wrapper(self, func) @@ -494,15 +491,6 @@ def function_ir(self) -> irbuilder.IRFunction: @property def op_signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" - if self._signature is not None: - return self._signature - - if self.op_schema is None: - return None - - self._signature = _schemas.OpSignature.from_function( - self.func, domain="_traced", name=self.name - ) return self._signature @op_signature.setter diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index ea8affc37d..6d3a20bbed 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -441,7 +441,7 @@ def from_function( for param in py_signature.parameters.values(): if param.name not in type_hints: - logger.warning( + logger.debug( "Missing annotation for parameter '%s' from %s. 
Treating as an Input.", param.name, py_signature, From 0fe2f6de68e6996694bf82f0fc67309d6a86f3e0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:43:44 -0800 Subject: [PATCH 07/10] Fix converter Signed-off-by: Justin Chu --- onnxscript/_internal/converter.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 6ab228ef4d..1c1c0963ad 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -12,7 +12,6 @@ Union, ) -import onnx import onnx_ir as ir import onnxscript @@ -29,6 +28,7 @@ from onnxscript._internal import ( type_annotation as ta, ) +from onnxscript.ir import _schemas logger = logging.getLogger("onnxscript") @@ -518,7 +518,7 @@ def _translate_attr( self, attr_name: str, expr: ast.AST, - attr_meta: onnx.defs.OpSchema.Attribute | None = None, + attr_meta: _schemas.AttributeParameter | None = None, ) -> ir.Attr | None: """Translate an attribute-value specification of the form `attr_name=` in a call to an op. expr is an AST. The following cases are supported: @@ -880,14 +880,11 @@ def _translate_call_expr( op_signature, node.args, kwargs, fill_defaults=False ) args = [self._translate_opt_expr(x) for x in args] - attrs = [ - self._translate_attr(x, y, callee.op_schema.attributes[x]) - for x, y in attrs.items() - ] + attrs = [self._translate_attr(x, y, op_signature.get(x)) for x, y in attrs.items()] else: args = [self._translate_opt_expr(x) for x in node.args] attrs = [self._translate_attr(x.arg, x.value) for x in node.keywords] - args = autocast.static_cast_inputs(self, callee.op_signature, args) + args = autocast.static_cast_inputs(self, op_signature, args) # In ONNX, there is no way to explicitly specify a None value for an attribute. # Instead, the attribute must be omitted from the attribute list. 
From 0e067b6932fa34a48e4095d4e734b704ce9393ac Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 11:49:11 -0800 Subject: [PATCH 08/10] homogeneous Signed-off-by: Justin Chu --- onnxscript/_internal/autocast.py | 4 +++- onnxscript/ir/_schemas.py | 9 ++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 59732d0c08..99d0e82f5b 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -161,7 +161,9 @@ def cast_inputs( expected = expected_inputs[i] elif expected_inputs[-1].variadic: expected = expected_inputs[-1] - # TODO(justinchuby): Handle is_homogeneous params + if not expected.homogeneous: + args_typevars.append((x, None)) + continue else: raise ValueError( f"Number of actual parameters {len(args)} " diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index 6d3a20bbed..1f14634834 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -106,8 +106,10 @@ class Parameter: type_constraint: TypeConstraintParam required: bool variadic: bool + homogeneous: bool = True + min_arity: int = 1 + # TODO: Add differentiation_category default: Any = _EMPTY_DEFAULT - # TODO: Add other properties too def __str__(self) -> str: type_str = self.type_constraint.name @@ -188,6 +190,8 @@ def _convert_formal_parameter( type_constraint=type_constraint, required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional, variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic, + homogeneous=param.is_homogeneous, + min_arity=param.min_arity, ) @@ -455,6 +459,7 @@ def from_function( required=param.default is inspect.Parameter.empty, # TODO: Handle variadic variadic=False, + homogeneous=True, default=param.default if param.default is not inspect.Parameter.empty else _EMPTY_DEFAULT, @@ -505,6 +510,7 @@ def from_function( required=param.default is inspect.Parameter.empty, # TODO: Handle variadic variadic=False, + homogeneous=True, default=param.default if param.default is not inspect.Parameter.empty else _EMPTY_DEFAULT, @@ -542,6 +548,7 @@ def from_function( type_constraint=type_constraint, required=True, variadic=False, + homogeneous=True, default=_EMPTY_DEFAULT, ) ) From 1109f261196a21ba93a0b36fe4e280bb82153853 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 12:07:11 -0800 Subject: [PATCH 09/10] fix call functions Signed-off-by: Justin Chu --- onnxscript/_internal/irbuilder.py | 4 ++-- onnxscript/_internal/values.py | 26 ++++++++++++-------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/onnxscript/_internal/irbuilder.py b/onnxscript/_internal/irbuilder.py index f287b6b1ab..1ae3c7bdb1 100644 --- a/onnxscript/_internal/irbuilder.py +++ b/onnxscript/_internal/irbuilder.py @@ -77,7 +77,7 @@ def append_parameter(self, parameter: ir.Value | ir.Attr) -> None: def add_nested_function(self, fun: IRFunction) -> None: self.nested_functions[fun.name] = fun - def get_called_functions(self) -> dict[str, ir.Function]: + def get_called_functions(self) -> dict[str, values.OnnxFunction]: called_functions: dict[str, values.OnnxFunction] = {} def visit(function_ir: IRFunction): @@ -94,7 +94,7 @@ def add(f: values.OnnxFunction): visit(self) - return {name: f.function_ir for name, f in called_functions.items()} + return called_functions def to_graph_proto(self) -> onnx.GraphProto: """Converts this instance into a `onnx.GraphProto`.""" diff --git a/onnxscript/_internal/values.py 
b/onnxscript/_internal/values.py index e964078a05..08dc6d32fc 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -11,7 +11,6 @@ import logging import types import typing -from collections.abc import Collection from typing import ( Any, Callable, @@ -351,7 +350,6 @@ def to_model_proto(self, **kwargs): def _to_model_proto( self, - functions: Collection[ir.Function] | None = None, io_types: Optional[ONNXType] = None, input_types: Optional[Sequence[ONNXType]] = None, output_types: Optional[Sequence[ONNXType]] = None, @@ -380,24 +378,24 @@ def _to_model_proto( An instance of :class:`onnx.ModelProto`. """ # Identify functions to include in the model - if functions is None: - sub_functions = self.function_ir.get_called_functions() - functions = sub_functions.values() + sub_functions = self.function_ir.get_called_functions() + functions = sub_functions.values() # Determine opset imports - opsets = self.function_ir.graph.opset_imports + opset_imports = self.function_ir.graph.opset_imports for func in functions: - if func.domain not in opsets: - opsets[func.domain] = 1 + domain = func.opset.domain + if domain is not None and domain not in opset_imports: + opset_imports[domain] = func.opset.version - # No need to collect opsets from functions + if "" not in opset_imports and "" in func.function_ir.opset_imports: + opset_imports[""] = func.function_ir.opset_imports[""] - # FIXME: Collect used opsets from the function nodes - if "" not in opsets: + if "" not in opset_imports: # No operator is using the standard opset. # Use the specified version if provided or the default value. - opsets[""] = ( + opset_imports[""] = ( opset_version if opset_version is not None else onnx.defs.onnx_opset_version() ) @@ -405,12 +403,12 @@ def _to_model_proto( if "ir_version" in kwargs: ir_version = kwargs.pop("ir_version") else: - ir_version = select_ir_version(opsets[""]) + ir_version = select_ir_version(opset_imports[""]) # Create the model model = ir.Model(self.function_ir.graph, ir_version=ir_version) for func in functions: - model.functions[func.identifier()] = func + model.functions[func.function_ir.identifier()] = func.function_ir model_proto = ir.to_proto(model) From 8bd5f52898318091561cce315116ac898e3c4442 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 16 Jan 2026 13:38:07 -0800 Subject: [PATCH 10/10] copilot Signed-off-by: Justin Chu --- onnxscript/_internal/evaluator.py | 10 ++++++++-- onnxscript/_internal/values.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index 1f29f15d68..d6075d29be 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -506,8 +506,14 @@ def _call_ort( return [_numpy_to_onnxscript_value(x) for x in result] -def _op_identifier(schema) -> tuple[str, str, int]: - return schema.name, schema.domain, schema.since_version +def _op_identifier( + op_schema_or_signature: onnx.defs.OpSchema | _schemas.OpSignature, +) -> tuple[str, str, int]: + return ( + op_schema_or_signature.name, + op_schema_or_signature.domain, + op_schema_or_signature.since_version, + ) class ORTEvaluator(BaseEvaluator): diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 08dc6d32fc..007b50a036 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -382,7 +382,7 @@ def _to_model_proto( functions = sub_functions.values() # Determine opset imports - opset_imports = self.function_ir.graph.opset_imports + 
opset_imports = self.function_ir.graph.opset_imports.copy() for func in functions: domain = func.opset.domain
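
A minimal usage sketch of the interface this series converges on, mirroring onnxscript/_internal/evaluator_test.py from PATCH 06. This is not part of the patch series: the opset version, the printed outputs, and the module paths are assumptions read off the diffs above.

    import numpy as np

    from onnxscript import FLOAT, script
    from onnxscript import opset18 as op
    from onnxscript._internal import evaluator


    @script()
    def square(x: FLOAT["N"]) -> FLOAT["N"]:  # noqa: F821
        return op.Mul(x, x)


    data = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    # Eager execution: OnnxFunction.__call__ dispatches to the default
    # evaluator's eval_function, and each Op inside resolves through
    # op_signature/eval_op rather than the removed
    # eval(schema, inputs, attributes) entry point (PATCH 03/06).
    print(square(data))

    # Switching evaluators: the series drops the script_fun[instance](x)
    # indexing idiom (removed in PATCH 03) in favor of the default_as
    # context manager, as PATCH 06 updates the tests to do.
    with evaluator.default_as(evaluator.ort_evaluator):
        print(square(data))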