onnxscript/_internal/autocast.py (14 additions, 13 deletions)

@@ -8,9 +8,9 @@
import numpy as np
import onnx
import onnx.helper  # noqa: TID251
-from onnx.defs import OpSchema

from onnxscript import ir, tensor
+from onnxscript.ir import _schemas

if TYPE_CHECKING:
    from onnxscript._internal import converter
@@ -126,7 +126,7 @@ def cast_pyvalue_to_os_tensor(pyvalue, dtype=None):
def cast_inputs(
    get_type_info: Callable[[Any], Any],
    cast: Callable[[Any, Any], Any],
-    op_schema: OpSchema | None,
+    op_signature: _schemas.OpSignature | None,
    args,
) -> tuple[Any, ...]:
    """Uses schema specification to support a limited form of auto-casting.
@@ -140,12 +140,15 @@ def cast_inputs(
    This is used by the converter in a static-mode, as well as by the eager-mode
    execution in a dynamic-mode.
    """
-    if op_schema is None:
+    if op_signature is None:
        # Either an error or a custom op.
        # No checks/casts in this case.
        return tuple(cast(x, None) for x in args)

-    expected_inputs = op_schema.inputs
+    # Filter to get only input parameters (not AttributeParameters)
+    expected_inputs = [
+        param for param in op_signature.params if isinstance(param, _schemas.Parameter)
+    ]
    # We make two passes. In the first pass, we identify known type-bindings for
    # type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}.
    # In the second pass, we use these bindings to cast scalar-values to
@@ -156,17 +159,15 @@
    for i, x in enumerate(args):
        if i < len(expected_inputs):
            expected = expected_inputs[i]
-        elif expected_inputs[-1].option == OpSchema.FormalParameterOption.Variadic:
+        elif expected_inputs[-1].variadic:
            expected = expected_inputs[-1]
-            if not expected.is_homogeneous:
-                args_typevars.append((x, None))
-                continue
+            # TODO(justinchuby): Handle is_homogeneous params
[Review comment — Copilot AI, Jan 15, 2026]
This TODO indicates that the previous is_homogeneous handling was removed but needs to be reimplemented. The old code had logic to handle non-homogeneous variadic parameters. This unfinished work could lead to bugs if non-homogeneous variadic parameters are encountered.
        else:
            raise ValueError(
                f"Number of actual parameters {len(args)} "
                f"exceeds number of formal parameters {len(expected_inputs)}."
            )
-        typevar = expected.type_str
+        typevar = expected.type_constraint.name
        if "(" not in typevar:
            # typevar is an identifier, like "T"
            typeinfo = get_type_info(x)
@@ -177,18 +178,18 @@
    return tuple(cast_args)


-def dynamic_cast_inputs(op_schema: OpSchema, args):
+def dynamic_cast_inputs(op_signature: _schemas.OpSignature, args):
    """Used for autocast during eager-mode execution."""

    def get_type_info(x):
        return x.dtype if isinstance(x, tensor.Tensor) else None

-    return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_schema, args)
+    return cast_inputs(get_type_info, cast_pyvalue_to_os_tensor, op_signature, args)


def static_cast_inputs(
    converter_: converter.Converter,
-    op_schema: Optional[OpSchema],
+    op_signature: Optional[_schemas.OpSignature],
    args: Sequence[Optional[ir.Value]],
) -> tuple[str, ...]:
    """Used for autocast during script-translation.
@@ -212,4 +213,4 @@ def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]:
            return converter_.emit1([x_cast], "CastLike", [x, y])
        return x

-    return cast_inputs(get_type_info, cast_like, op_schema, args)
+    return cast_inputs(get_type_info, cast_like, op_signature, args)
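To make the two-pass scheme in cast_inputs concrete, here is a minimal, self-contained sketch of the idea. Plain numpy arrays stand in for onnxscript's tensor and _schemas types, and sketch_cast_inputs is an illustrative name, not the real API:

import numpy as np

def sketch_cast_inputs(param_typevars, args):
    # Pass 1: bind each type variable to the dtype of any argument that has one.
    bindings = {}
    for typevar, x in zip(param_typevars, args):
        if isinstance(x, np.ndarray):
            bindings[typevar] = x.dtype
    # Pass 2: cast Python scalars to the dtype bound to their type variable.
    cast_args = []
    for typevar, x in zip(param_typevars, args):
        if isinstance(x, np.ndarray):
            cast_args.append(x)
        else:
            cast_args.append(np.asarray(x, dtype=bindings.get(typevar)))
    return tuple(cast_args)

# Both inputs of an op like Add share the type variable "T", so the Python
# scalar 1 is cast to float32 instead of defaulting to int64.
x = np.ones(3, dtype=np.float32)
_, one = sketch_cast_inputs(["T", "T"], [x, 1])
print(one.dtype)  # float32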
onnxscript/_internal/converter.py (3 additions, 3 deletions)

@@ -887,7 +887,7 @@ def _translate_call_expr(
        else:
            args = [self._translate_opt_expr(x) for x in node.args]
        attrs = [self._translate_attr(x.arg, x.value) for x in node.keywords]
-        args = autocast.static_cast_inputs(self, callee.op_schema, args)
+        args = autocast.static_cast_inputs(self, callee.op_signature, args)

        # In ONNX, there is no way to explicitly specify a None value for an attribute.
        # Instead, the attribute must be omitted from the attribute list.
@@ -896,8 +896,8 @@
        return callee, args, attrs

    def _cast_like_binary_expression(self, op, left, right) -> tuple[ir.Value, ir.Value]:
-        schema = op.op_schema
-        return autocast.static_cast_inputs(self, schema, (left, right))
+        op_signature = op.op_signature
+        return autocast.static_cast_inputs(self, op_signature, (left, right))

    def _translate_binary_op_expr(self, node: ast.BinOp):
        op = type(node.op)
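The CastLike node that static_cast_inputs emits casts its first input to the element type of its second input, which is how a bare Python literal in a binary expression such as x + 1.0 picks up the tensor's dtype. A rough numpy analogue of that semantics (cast_like below is an illustrative stand-in, not the converter's helper):

import numpy as np

def cast_like(x, y):
    # Analogue of ONNX CastLike: cast x to the element type of y.
    return np.asarray(x, dtype=np.asarray(y).dtype)

x = np.array([1.5, 2.5], dtype=np.float32)
print(cast_like(2, x).dtype)  # float32, so the binary op sees matching input types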
onnxscript/_internal/evaluator.py (18 additions, 17 deletions)

@@ -188,39 +188,39 @@
            inputs: The ONNX inputs to the op.
            attributes: The ONNX attributes to the op.
        """
+        op_signature = _schemas.OpSignature.from_op_schema(schema)
        attributes = _unwrap_tensors_in_kwargs(attributes)
-        attributes, closure = self.adapt_attributes(schema, attributes)
-        inputs = self.adapt_inputs(schema, inputs)
+        attributes, closure = self._adapt_attributes(op_signature, attributes)
+        inputs = self._adapt_inputs(op_signature, inputs)
        outputs = self._eval(schema, inputs, attributes, closure)
-        return self.adapt_outputs(schema, outputs)
+        return self._adapt_outputs(outputs)

-    def adapt_inputs(self, schema: onnx.defs.OpSchema, inputs: Sequence[ExtendedModeValue]):
+    def _adapt_inputs(
+        self, op_signature: _schemas.OpSignature, inputs: Sequence[ExtendedModeValue]
+    ):
        """Transform inputs to the expected format for the evaluator.

        Enables some syntactic sugar, such as the use of Python scalars,
        in a manner consistent with the translator. See autocast.py for details.
        """
-        return autocast.dynamic_cast_inputs(schema, inputs)
+        return autocast.dynamic_cast_inputs(op_signature, inputs)

-    def adapt_attributes(
-        self, schema: onnx.defs.OpSchema, attributes: Mapping[str, ExtendedModeValue]
+    def _adapt_attributes(
+        self, op_signature, attributes: Mapping[str, ExtendedModeValue]

[Code scanning — lintrunner, PYLINT/W0613 warning] Unused argument 'op_signature' (unused-argument). To disable, use # pylint: disable=unused-argument

    ) -> tuple[dict[str, ExtendedModeValue], dict[str, ExtendedModeValue]]:
        """Transform attributes to the expected format for the evaluator.

        Returns:
            A closure that can be used to evaluate graph-valued attributes.
        """
-        use_graph_attribute = self.use_graph_attribute(schema)
+        use_graph_attribute = self.use_graph_attribute(op_singature)

[GitHub Actions / Optional Lint — misspell, line 216] "singature" is a misspelling of "signature" (./onnxscript/_internal/evaluator.py:216:58)
[Code scanning — lintrunner, PYLINT/E0602 error] Undefined variable 'op_singature' (undefined-variable)
[Code scanning — lintrunner, RUFF/F821 error] Undefined name op_singature (https://docs.astral.sh/ruff/rules/undefined-name)
[Code scanning — lintrunner, PYLINT/W0612 warning] Unused variable 'use_graph_attribute' (unused-variable)
[Code scanning — lintrunner, RUFF/F841 warning] Local variable use_graph_attribute is assigned to but never used (https://docs.astral.sh/ruff/rules/unused-variable)
        closure: dict[Any, Any] = {}
        adapted_attributes = {}
        for k, v in attributes.items():
            if isinstance(v, values.OnnxClosure):
-                if use_graph_attribute:
-                    adapted_attributes[k] = v.function_ir.to_graph_proto()
-                    for pyvar, onnxvar in v.function_ir.outer_scope_variables:
-                        closure[onnxvar.value.name] = v.frame.f_locals[pyvar]
-                else:
-                    adapted_attributes[k] = v.function
+                adapted_attributes[k] = v.function_ir.to_graph_proto()
+                for pyvar, onnxvar in v.function_ir.outer_scope_variables:
+                    closure[onnxvar.value.name] = v.frame.f_locals[pyvar]
[Review comment — Copilot AI, Jan 15, 2026, on lines +221 to +223]
The removal of the use_graph_attribute conditional means that _adapt_attributes now always converts OnnxClosure to graph proto format. This is a behavioral change that affects the ORTMixedEvaluator class, which previously could use the function directly for registered python ops instead of graph proto. This could break existing functionality where python implementations were registered via ORTMixedEvaluator.register().
Suggested change:
-                adapted_attributes[k] = v.function_ir.to_graph_proto()
-                for pyvar, onnxvar in v.function_ir.outer_scope_variables:
-                    closure[onnxvar.value.name] = v.frame.f_locals[pyvar]
+                # If the closure captures outer-scope variables, we must materialize
+                # a graph proto and populate the closure mapping so that evaluators
+                # can rebind these values when executing the attribute.
+                if v.function_ir.outer_scope_variables:
+                    adapted_attributes[k] = v.function_ir.to_graph_proto()
+                    for pyvar, onnxvar in v.function_ir.outer_scope_variables:
+                        closure[onnxvar.value.name] = v.frame.f_locals[pyvar]
+                else:
+                    # For closures without captured outer-scope variables, avoid
+                    # forcing a graph-valued attribute so that evaluators such as
+                    # ORTMixedEvaluator can continue to use the underlying function
+                    # implementation directly (e.g., for registered python ops).
+                    adapted_attributes[k] = v.function_ir

            elif callable(v):
                raise TypeError(
                    f"Error: function-valued attribute {v.__name__} has no graph_proto"
@@ -230,18 +230,19 @@
                adapted_attributes[k] = v
        return adapted_attributes, closure

-    def adapt_outputs(self, schema: onnx.defs.OpSchema, outputs: Sequence[EagerModeValue]):
+    def _adapt_outputs(self, outputs: Sequence[EagerModeValue]):
        """Adapt evaluator's output to convention used in onnxscript.

        Onnxscript uses a tuple/sequence only when number of outputs > 1.
        """
-        del schema  # unused
        return outputs[0] if len(outputs) == 1 else outputs

-    def use_graph_attribute(self, schema: onnx.defs.OpSchema):
+    def use_graph_attribute(self, op_signature: _schemas.OpSignature) -> bool:

[Code scanning — lintrunner, PYLINT/W0613 warning] Unused argument 'op_signature' (unused-argument). To disable, use # pylint: disable=unused-argument
-        del schema  # unused
        return True

    @abc.abstractmethod
    def _eval(
        self,
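To make the review comment's concern concrete: the pre-change contract let a subclass opt out of graph-proto conversion per op via use_graph_attribute. A minimal sketch of that dispatch pattern, with hypothetical names (MixedEvaluator, _python_ops) standing in for the real ORTMixedEvaluator machinery:

class MixedEvaluator:
    """Sketch: prefers registered Python implementations over graph protos."""

    def __init__(self):
        self._python_ops = {}  # op name -> Python implementation

    def register(self, name, fn):
        self._python_ops[name] = fn

    def use_graph_attribute(self, op_name) -> bool:
        # Ops with a registered Python implementation take the function
        # directly; only the rest need graph-valued attributes.
        return op_name not in self._python_ops

# Unconditionally calling v.function_ir.to_graph_proto() (the new code path)
# never consults use_graph_attribute, which is the behavioral change the
# review comment flags.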