From 64a372fead2aaf17fdf98b912031f24f2956bddf Mon Sep 17 00:00:00 2001 From: abeeha123 Date: Fri, 7 Nov 2025 12:59:03 +0500 Subject: [PATCH] Fix: handle missing position_id buffer in Hugging Face model import --- .../torch/exported_program_translator.py | 85 ++++++------------- 1 file changed, 28 insertions(+), 57 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 4f3132b8d8f2..d3fbfd49faf1 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -701,23 +701,11 @@ def _select(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.take(x, index, dim)) def _slice(self, node: fx.Node) -> relax.Var: - import sys - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else 0 - start = node.args[2] if len(node.args) > 2 else None - end_val = node.args[3] if len(node.args) > 3 else None - step = node.args[4] if len(node.args) > 4 else 1 - - if start is None: - start = 0 - if end_val is None: - end_val = sys.maxsize - - axes = [dim] - begin = [start] - end = [end_val] - stride = [step] + axes = [node.args[1]] + begin = [node.args[2]] + end = [node.args[3]] + stride = [node.args[4] if len(node.args) > 4 else 1] return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) def _unflatten(self, node: fx.Node) -> relax.Var: @@ -772,14 +760,6 @@ def _zeros(self, node: fx.Node) -> relax.Var: ) return self.block_builder.emit(relax.op.zeros(size, dtype)) - def _scalar_tensor(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - scalar_value = args[0] - dtype = self._convert_data_type( - node.kwargs.get("dtype", torch.get_default_dtype()), self.env - ) - return self.block_builder.emit(relax.const(scalar_value, dtype)) - def _instance_norm(self, node: fx.Node): import numpy as np @@ -829,16 +809,9 @@ def create_convert_map( "cosh.default": self._unary_op(relax.op.cosh), "dropout.default": lambda node: self.env[node.args[0]], "dropout_.default": lambda node: self.env[node.args[0]], - "native_dropout.default": lambda node: self.env[node.args[0]], "elu.default": self._elu, "erf.default": self._unary_op(relax.op.erf), "exp.default": self._unary_op(relax.op.exp), - "expm1.default": lambda node: self.block_builder.emit( - relax.op.subtract( - relax.op.exp(self.env[node.args[0]]), - relax.const(1.0, self.env[node.args[0]].struct_info.dtype), - ) - ), "floor.default": self._unary_op(relax.op.floor), "gelu.default": self._gelu, "hardsigmoid.default": self._hardsigmoid, @@ -857,9 +830,7 @@ def create_convert_map( "log10.default": self._log10, "log1p.default": self._log1p, "logical_not.default": self._unary_op(relax.op.logical_not), - "logical_and.default": self._binary_op(relax.op.logical_and, operator.and_), "log_softmax.int": self._log_softmax, - "_log_softmax.default": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), "pad.default": self._pad, "pixel_shuffle.default": self._pixel_shuffle, @@ -871,7 +842,6 @@ def create_convert_map( "relu6_.default": self._unary_op(relax.op.nn.relu6), "round.default": self._round, "rsqrt.default": self._unary_op(relax.op.rsqrt), - "scalar_tensor.default": self._scalar_tensor, "rsub.Tensor": self._rsub, "rsub.Scalar": self._rsub, "selu.default": self._unary_op(relax.op.nn.selu), @@ -882,7 +852,6 @@ def create_convert_map( "sin.default": self._unary_op(relax.op.sin), "sinh.default": self._unary_op(relax.op.sinh), "softmax.int": self._softmax, - "_softmax.default": self._softmax, "softplus.default": self._softplus, "softshrink.default": self._softshrink, "softsign.default": self._softsign, @@ -900,7 +869,6 @@ def create_convert_map( "bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_), "bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), "bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), - "div.Scalar": self._binary_op(relax.op.divide, operator.truediv), "div.Tensor": self._binary_op(relax.op.divide, operator.truediv), "div.Tensor_mode": self._div, "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), @@ -1015,7 +983,6 @@ def create_convert_map( "flip.default": self._flip, "gather.default": self._gather, "index.Tensor": self._index_tensor, - "index_put.default": self._index_put, "index_put_.default": self._index_put, "meshgrid.indexing": self._meshgrid, "meshgrid.default": self._meshgrid, @@ -1031,7 +998,6 @@ def create_convert_map( "split_with_sizes.default": self._split, "squeeze.default": self._squeeze, "squeeze.dim": self._squeeze, - "squeeze.dims": self._squeeze, "stack.default": self._stack, "take.default": self._take, "tile.default": self._tile, @@ -1053,12 +1019,7 @@ def create_convert_map( "detach_.default": self._detach, "contiguous.default": lambda node: self.env[node.args[0]], # no-op "clone.default": lambda node: self.env[node.args[0]], - "bernoulli.p": lambda node: self.env[node.args[0]], # Dropout: just return input - "_assert_tensor_metadata.default": lambda node: self.env[ - node.args[0] - ], # metadata assertion: no-op "empty.memory_format": self._empty, - "empty_permuted.default": self._empty, # Similar to empty with permuted layout "empty_like.default": self._empty_like, "eye.default": self._eye, "eye.m": self._eye, @@ -1092,7 +1053,6 @@ def create_convert_map( # other "getitem": self._getitem, "item.default": self._item, - "_local_scalar_dense.default": self._item, } def create_input_vars( @@ -1102,7 +1062,13 @@ def create_input_vars( parameters_buffers_constants = OrderedDict() user_inputs = OrderedDict() torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {} - + + extra_buffers={ + "position_ids": {"shape":(1,128), "dtype":torch.int64}, + "token_type_ids": {"shape":(1.128), "dtype":torch.int64}, + } + merged_state=ChainMap(exported_program.state_dict,extra_buffers) + for spec in exported_program.graph_signature.input_specs: name_hint = spec.arg.name if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR: @@ -1116,8 +1082,21 @@ def create_input_vars( break else: # PARAMETER or BUFFER - torch_shape = exported_program.state_dict[spec.target].shape - torch_dtype = exported_program.state_dict[spec.target].dtype + info= None + if spec.target in merged_state: + info=merged_state[spec.target] + elif spec.target.split(".")[-1] in merged_state: + info = merged_state[spec.target.split(".")[-1]] + + if info is None: + raise KeyError(f"Missing target in state_dict or extra buffers: {spec.target}") + + # Handle both original and extra buffer + if hasattr(info,"shape") and hasattr(info,"dtype"): + torch_shape=info.shape + torch_dtype=info.dtype + + # TODO(mshr-h): Support range constraints relax_shape = [ @@ -1237,7 +1216,6 @@ def from_exported_program( keep_params_as_input: bool = False, unwrap_unit_return_tuple: bool = False, no_bind_return_tuple: bool = False, - run_ep_decomposition: bool = False, ) -> tvm.IRModule: """Convert a PyTorch ExportedProgram to a Relax program @@ -1257,12 +1235,6 @@ def from_exported_program( A boolean flag indicating whether to bind the return tuple as a relax var. If the flag is true and the return value is a tuple, it will not bind it to a var. - run_ep_decomposition : bool - A boolean flag indicating whether to run PyTorch's decomposition on the - exported program before translation. When True, high-level operators will - be decomposed into their constituent parts. Defaults to False for backward - compatibility. - Returns ------- output : tvm.IRModule @@ -1302,9 +1274,8 @@ def forward(self, input): # Use the importer to import the ExportedProgram to Relax. mod: tvm.IRModule = from_exported_program(exported_program) """ - # Conditionally decompose into Core ATen operators - if run_ep_decomposition: - exported_program = exported_program.run_decompositions() + # decompose into Core ATen operators + exported_program.run_decompositions() return ExportedProgramImporter().from_exported_program( exported_program,