85 changes: 28 additions & 57 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -701,23 +701,11 @@ def _select(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.take(x, index, dim))

def _slice(self, node: fx.Node) -> relax.Var:
import sys

x = self.env[node.args[0]]
dim = node.args[1] if len(node.args) > 1 else 0
start = node.args[2] if len(node.args) > 2 else None
end_val = node.args[3] if len(node.args) > 3 else None
step = node.args[4] if len(node.args) > 4 else 1

if start is None:
start = 0
if end_val is None:
end_val = sys.maxsize

axes = [dim]
begin = [start]
end = [end_val]
stride = [step]
axes = [node.args[1]]
begin = [node.args[2]]
end = [node.args[3]]
stride = [node.args[4] if len(node.args) > 4 else 1]
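# dim/start/end are read positionally here, so the exported graph is expected
# to supply them explicitly; only the step falls back to a default of 1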
return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))

def _unflatten(self, node: fx.Node) -> relax.Var:
@@ -772,14 +760,6 @@ def _zeros(self, node: fx.Node) -> relax.Var:
)
return self.block_builder.emit(relax.op.zeros(size, dtype))

def _scalar_tensor(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
scalar_value = args[0]
dtype = self._convert_data_type(
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
)
return self.block_builder.emit(relax.const(scalar_value, dtype))

def _instance_norm(self, node: fx.Node):
import numpy as np

@@ -829,16 +809,9 @@ def create_convert_map(
"cosh.default": self._unary_op(relax.op.cosh),
"dropout.default": lambda node: self.env[node.args[0]],
"dropout_.default": lambda node: self.env[node.args[0]],
"native_dropout.default": lambda node: self.env[node.args[0]],
"elu.default": self._elu,
"erf.default": self._unary_op(relax.op.erf),
"exp.default": self._unary_op(relax.op.exp),
"expm1.default": lambda node: self.block_builder.emit(
relax.op.subtract(
relax.op.exp(self.env[node.args[0]]),
relax.const(1.0, self.env[node.args[0]].struct_info.dtype),
)
),
"floor.default": self._unary_op(relax.op.floor),
"gelu.default": self._gelu,
"hardsigmoid.default": self._hardsigmoid,
@@ -857,9 +830,7 @@
"log10.default": self._log10,
"log1p.default": self._log1p,
"logical_not.default": self._unary_op(relax.op.logical_not),
"logical_and.default": self._binary_op(relax.op.logical_and, operator.and_),
"log_softmax.int": self._log_softmax,
"_log_softmax.default": self._log_softmax,
"neg.default": self._unary_op(relax.op.negative),
"pad.default": self._pad,
"pixel_shuffle.default": self._pixel_shuffle,
@@ -871,7 +842,6 @@
"relu6_.default": self._unary_op(relax.op.nn.relu6),
"round.default": self._round,
"rsqrt.default": self._unary_op(relax.op.rsqrt),
"scalar_tensor.default": self._scalar_tensor,
"rsub.Tensor": self._rsub,
"rsub.Scalar": self._rsub,
"selu.default": self._unary_op(relax.op.nn.selu),
@@ -882,7 +852,6 @@
"sin.default": self._unary_op(relax.op.sin),
"sinh.default": self._unary_op(relax.op.sinh),
"softmax.int": self._softmax,
"_softmax.default": self._softmax,
"softplus.default": self._softplus,
"softshrink.default": self._softshrink,
"softsign.default": self._softsign,
@@ -900,7 +869,6 @@
"bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
"bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
"bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
"div.Scalar": self._binary_op(relax.op.divide, operator.truediv),
"div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
"div.Tensor_mode": self._div,
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
@@ -1015,7 +983,6 @@ def create_convert_map(
"flip.default": self._flip,
"gather.default": self._gather,
"index.Tensor": self._index_tensor,
"index_put.default": self._index_put,
"index_put_.default": self._index_put,
"meshgrid.indexing": self._meshgrid,
"meshgrid.default": self._meshgrid,
@@ -1031,7 +998,6 @@
"split_with_sizes.default": self._split,
"squeeze.default": self._squeeze,
"squeeze.dim": self._squeeze,
"squeeze.dims": self._squeeze,
"stack.default": self._stack,
"take.default": self._take,
"tile.default": self._tile,
@@ -1053,12 +1019,7 @@
"detach_.default": self._detach,
"contiguous.default": lambda node: self.env[node.args[0]], # no-op
"clone.default": lambda node: self.env[node.args[0]],
"bernoulli.p": lambda node: self.env[node.args[0]], # Dropout: just return input
"_assert_tensor_metadata.default": lambda node: self.env[
node.args[0]
], # metadata assertion: no-op
"empty.memory_format": self._empty,
"empty_permuted.default": self._empty, # Similar to empty with permuted layout
"empty_like.default": self._empty_like,
"eye.default": self._eye,
"eye.m": self._eye,
@@ -1092,7 +1053,6 @@ def create_convert_map(
# other
"getitem": self._getitem,
"item.default": self._item,
"_local_scalar_dense.default": self._item,
}

def create_input_vars(
@@ -1102,7 +1062,13 @@ def create_input_vars(
parameters_buffers_constants = OrderedDict()
user_inputs = OrderedDict()
torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}


extra_buffers = {
"position_ids": {"shape": (1, 128), "dtype": torch.int64},
"token_type_ids": {"shape": (1, 128), "dtype": torch.int64},
}
merged_state = ChainMap(exported_program.state_dict, extra_buffers)
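# Note: ChainMap (from the standard-library collections module) checks
# exported_program.state_dict first and only falls back to extra_buffers,
# so real weights always take precedence over these hard-coded (1, 128) entries.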

for spec in exported_program.graph_signature.input_specs:
name_hint = spec.arg.name
if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
@@ -1116,8 +1082,21 @@
break
else:
# PARAMETER or BUFFER
torch_shape = exported_program.state_dict[spec.target].shape
torch_dtype = exported_program.state_dict[spec.target].dtype
info = None
if spec.target in merged_state:
info = merged_state[spec.target]
elif spec.target.split(".")[-1] in merged_state:
# fall back to the bare buffer name (e.g. "position_ids") when the full
# dotted target is absent from state_dict
info = merged_state[spec.target.split(".")[-1]]

if info is None:
raise KeyError(f"Missing target in state_dict or extra buffers: {spec.target}")

# Handle both original tensors and extra-buffer entries
if hasattr(info, "shape") and hasattr(info, "dtype"):
torch_shape = info.shape
torch_dtype = info.dtype
else:
# extra_buffers entries are plain dicts, not tensors
torch_shape = info["shape"]
torch_dtype = info["dtype"]

# TODO(mshr-h): Support range constraints
relax_shape = [
@@ -1237,7 +1216,6 @@ def from_exported_program(
keep_params_as_input: bool = False,
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
run_ep_decomposition: bool = False,
) -> tvm.IRModule:
"""Convert a PyTorch ExportedProgram to a Relax program

@@ -1257,12 +1235,6 @@
A boolean flag indicating whether to bind the return tuple as a relax var.
If the flag is true and the return value is a tuple, it will not bind it to a var.

run_ep_decomposition : bool
A boolean flag indicating whether to run PyTorch's decomposition on the
exported program before translation. When True, high-level operators will
be decomposed into their constituent parts. Defaults to False for backward
compatibility.

Returns
-------
output : tvm.IRModule
Expand Down Expand Up @@ -1302,9 +1274,8 @@ def forward(self, input):
# Use the importer to import the ExportedProgram to Relax.
mod: tvm.IRModule = from_exported_program(exported_program)
"""
# Conditionally decompose into Core ATen operators
if run_ep_decomposition:
exported_program = exported_program.run_decompositions()
# Decompose into Core ATen operators. run_decompositions() returns a new
# ExportedProgram rather than mutating in place, so reassign the result.
exported_program = exported_program.run_decompositions()

return ExportedProgramImporter().from_exported_program(
exported_program,