From 5247e8cbdf93fb4b768534f17fe082ca75970033 Mon Sep 17 00:00:00 2001 From: LZY2006 <58110034+Lee-qian-gay@users.noreply.github.com> Date: Sun, 10 Jan 2021 12:34:44 +0800 Subject: [PATCH 01/12] Signed-off-by: LZY2006 --- .gitignore | 8 + data/torch_code.txt | 13651 ++++++++++++++++++++++++++++++++++++++++++ model.py | 31 +- out.txt | Bin 0 -> 672 bytes read_utils.py | 3 +- sample.py | 5 +- train.py | 20 +- 7 files changed, 13702 insertions(+), 16 deletions(-) create mode 100644 data/torch_code.txt create mode 100644 out.txt diff --git a/.gitignore b/.gitignore index 83f0b80..d74fccd 100644 --- a/.gitignore +++ b/.gitignore @@ -102,3 +102,11 @@ ENV/ # custom model/ + +sampling/ +training/ + +data/Bilibili.txt +data/TheThreeBodyProblem.txt +data/WangZengQi.txt +data/ZhaoHuaXiShi.txt \ No newline at end of file diff --git a/data/torch_code.txt b/data/torch_code.txt new file mode 100644 index 0000000..7855c8f --- /dev/null +++ b/data/torch_code.txt @@ -0,0 +1,13651 @@ +from typing import ( + Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING +) import torch +import torch.nn.functional as F +from torch.types import _size +from ._lowrank import svd_lowrank, pca_lowrank +from .overrides import has_torch_function, handle_torch_function +from ._jit_internal import boolean_dispatch, List +from ._jit_internal import _overload as overload Tensor = torch.Tensor +from torch import _VF __all__ = [ + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', +] +def broadcast_tensors(*tensors): + r + if not torch.jit.is_scripting(): + if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + return handle_torch_function(broadcast_tensors, tensors, *tensors) + return _VF.broadcast_tensors(tensors) +def split(tensor, split_size_or_sections, dim=0): + r + if not torch.jit.is_scripting(): + if type(tensor) is not Tensor and has_torch_function((tensor,)): + return handle_torch_function(split, (tensor,), tensor, split_size_or_sections, + dim=dim) + + + + + return tensor.split(split_size_or_sections, dim) +if TYPE_CHECKING: + _Indices = _size +else: + _Indices = List[int] def _indices_product(indices: _Indices) -> List[List[int]]: + empty_list = torch.jit.annotate(List[int], []) + result = [empty_list] + for idx in indices: + result_temp = torch.jit.annotate(List[List[int]], []) + for res in result: + for i in range(idx): + result_temp.append(res + [i]) + result = result_temp + return result +def _index_tensor_with_indices_list(tensor, indices): + + out = tensor + for index in indices: + out = out[index] + return out +def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): + + r + if not torch.jit.is_scripting(): + tens_ops = (LU_data, LU_pivots) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + lu_unpack, tens_ops, LU_data, LU_pivots, unpack_data=unpack_data, + unpack_pivots=unpack_pivots) + shape = LU_data.shape + + + + + + + m, n = shape[-2:] + k = min(m, n) + if unpack_data: + U: Optional[Tensor] = LU_data.triu() + assert U is not None + if m != k: + U = U.narrow(-2, 0, k) + L: Optional[Tensor] = LU_data.tril() + assert L is not None + if k != n: + L = L.narrow(-1, 0, k) + L.diagonal(dim1=-2, dim2=-1).fill_(1) + else: + L = U = None if unpack_pivots: + LU_pivots_zero_idx = LU_pivots - 1 + if LU_data.dim() > 2: + P: Optional[Tensor] = torch.eye(m, device=LU_data.device, + dtype=LU_data.dtype) \ + .expand(shape[:-1] + (m,)) \ + .clone(memory_format=torch.contiguous_format) + assert P is not None indices = _indices_product(shape[:-2]) + for idx in indices: + final_order = [i for i in range(m)] + for k, j in enumerate(_index_tensor_with_indices_list(LU_pivots_zero_idx, idx)): + final_order[k], final_order[j] = final_order[j], final_order[k] + p_idx = _index_tensor_with_indices_list(P, idx) + p_idx.copy_(p_idx.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))) + else: + P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype) + final_order = [i for i in range(m)] + for k, j, in enumerate(LU_pivots_zero_idx): + final_order[k], final_order[j] = final_order[j], final_order[k] + P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)) + else: + P = None return P, L, U +def einsum(equation, *operands): + r + if not torch.jit.is_scripting(): + if any(type(t) is not Tensor for t in operands) and has_torch_function(operands): + return handle_torch_function(einsum, operands, equation, *operands) + if len(operands) == 1 and isinstance(operands[0], (list, tuple)): + _operands = operands[0] + return einsum(equation, *_operands) return _VF.einsum(equation, operands) +if TYPE_CHECKING: + + def meshgrid(*tensors: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, ...]: + return _meshgrid(*tensors) +else: + def meshgrid(*tensors): + return _meshgrid(*tensors) +def _meshgrid(*tensors): + r + if not torch.jit.is_scripting(): + if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + return handle_torch_function(meshgrid, tensors, *tensors) + if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): + tensors = tensors[0] + return _VF.meshgrid(tensors) +def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, + win_length: Optional[int] = None, window: Optional[Tensor] = None, + center: bool = True, pad_mode: str = 'MSG', normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None) -> Tensor: + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, pad_mode=pad_mode, normalized=normalized, + onesided=onesided, return_complex=return_complex) + + + if center: + signal_dim = input.dim() + extended_shape = [1] * (3 - signal_dim) + list(input.size()) + pad = int(n_fft // 2) + input = F.pad(input.view(extended_shape), (pad, pad), pad_mode) + input = input.view(input.shape[-signal_dim:]) + return _VF.stft(input, n_fft, hop_length, win_length, window, + normalized, onesided, return_complex) def istft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, + win_length: Optional[int] = None, window: Optional[Tensor] = None, + center: bool = True, normalized: bool = False, + onesided: Optional[bool] = None, length: Optional[int] = None, + return_complex: bool = False) -> Tensor: + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, normalized=normalized, onesided=onesided, + length=length, return_complex=return_complex) return _VF.istft(input, n_fft, hop_length, win_length, window, center, + normalized, onesided, length, return_complex) +del torch.unique_dim +if TYPE_CHECKING: + + + + _unique_impl_out = Any +else: + _unique_impl_out = Tuple[Tensor, Tensor, Tensor] +def _unique_impl(input: Tensor, sorted: bool = True, + return_inverse: bool = False, return_counts: bool = False, + dim: Optional[int] = None) -> _unique_impl_out: + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + unique, (input,), input, sorted=sorted, return_inverse=return_inverse, + return_counts=return_counts, dim=dim) if dim is not None: + output, inverse_indices, counts = _VF.unique_dim( + input, + dim, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + else: + output, inverse_indices, counts = torch._unique2( + input, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + return output, inverse_indices, counts +def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False, + return_counts: bool = False, + dim: Optional[int] = None) -> _unique_impl_out: + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + unique_consecutive, (input,), input, return_inverse=return_inverse, + return_counts=return_counts, dim=dim) + output, inverse_indices, counts = _VF.unique_consecutive( + input, return_inverse=return_inverse, return_counts=return_counts, dim=dim) + return output, inverse_indices, counts +def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None): + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim) + return output, counts +def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None): + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) + return output +def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None): + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) + return output, inverse_indices +_return_inverse_false = boolean_dispatch( + arg_name='MSG', + arg_index=3, + default=False, + if_true=_return_counts, + if_false=_return_output, + module_name=__name__, + func_name='MSG') _return_inverse_true = boolean_dispatch( + arg_name='MSG', + arg_index=3, + default=False, + if_true=_unique_impl, + if_false=_return_inverse, + module_name=__name__, + func_name='MSG') +unique = boolean_dispatch( + arg_name='MSG', + arg_index=2, + default=False, + if_true=_return_inverse_true, + if_false=_return_inverse_false, + module_name=__name__, + func_name='MSG') +unique.__doc__ = _unique_impl.__doc__ +def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None): + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim) + return output, counts +def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None): + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) + return output +def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None): + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) + return output, inverse_indices +_consecutive_return_inverse_false = boolean_dispatch( + arg_name='MSG', + arg_index=1, + default=False, + if_true=_consecutive_return_counts, + if_false=_consecutive_return_output, + module_name=__name__, + func_name='MSG') _consecutive_return_inverse_true = boolean_dispatch( + arg_name='MSG', + arg_index=1, + default=False, + if_true=_unique_consecutive_impl, + if_false=_consecutive_return_inverse, + module_name=__name__, + func_name='MSG') +unique_consecutive = boolean_dispatch( + arg_name='MSG', + arg_index=2, + default=False, + if_true=_consecutive_return_inverse_true, + if_false=_consecutive_return_inverse_false, + module_name=__name__, + func_name='MSG') +unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__ +def tensordot(a, b, dims=2): + r + if not torch.jit.is_scripting(): + if (type(a) is not Tensor or type(b) is not Tensor) and has_torch_function((a, b)): + return handle_torch_function(tensordot, (a, b), a, b, dims=dims) + if isinstance(dims, (list, tuple)) or \ + (isinstance(dims, torch.Tensor) and dims.numel() > 1): + dims_a, dims_b = dims + else: + if isinstance(dims, torch.Tensor): + dims = dims.item() + if dims < 0: + raise RuntimeError(f"MSG") + dims_a = list(range(-dims, 0)) + dims_b = list(range(dims)) + return _VF.tensordot(a, b, dims_a, dims_b) def cartesian_prod(*tensors): + + if not torch.jit.is_scripting(): + if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + return handle_torch_function(cartesian_prod, tensors, *tensors) + return _VF.cartesian_prod(tensors) def block_diag(*tensors): + + if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + return handle_torch_function(block_diag, tensors, *tensors) + return torch._C._VariableFunctions.block_diag(tensors) +def cdist(x1, x2, p=2., compute_mode='MSG'): + + r + if not torch.jit.is_scripting(): + if (type(x1) is not Tensor or type(x2) is not Tensor) and has_torch_function((x1, x2)): + return handle_torch_function( + cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode) + if compute_mode == 'MSG': + return _VF.cdist(x1, x2, p, None) + elif compute_mode == 'MSG': + return _VF.cdist(x1, x2, p, 1) + elif compute_mode == 'MSG': + return _VF.cdist(x1, x2, p, 2) + else: + raise ValueError(f"MSG") def atleast_1d(*tensors): + r + if not torch.jit.is_scripting(): + if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + return handle_torch_function(atleast_1d, tensors, *tensors) + if len(tensors) == 1: + tensors = tensors[0] + return _VF.atleast_1d(tensors) def atleast_2d(*tensors): + r + if not torch.jit.is_scripting(): + if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + return handle_torch_function(atleast_2d, tensors, *tensors) + if len(tensors) == 1: + tensors = tensors[0] + return _VF.atleast_2d(tensors) def atleast_3d(*tensors): + r + if not torch.jit.is_scripting(): + if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + return handle_torch_function(atleast_3d, tensors, *tensors) + if len(tensors) == 1: + tensors = tensors[0] + return _VF.atleast_3d(tensors) +if TYPE_CHECKING: + pass + + + + + + + + + + +else: + + + @overload + def norm(input, p="MSG", dim=None, keepdim=False, out=None, dtype=None): + pass @overload + def norm(input, p="MSG", dim=None, keepdim=False, out=None, dtype=None): + pass @overload + def norm(input, p="MSG", dim=None, keepdim=False, out=None, dtype=None): + pass @overload + def norm(input, p="MSG", dim=None, keepdim=False, out=None, dtype=None): + pass +def norm(input, p="MSG", dim=None, keepdim=False, out=None, dtype=None): + r if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype) ndim = input.dim() + if dim is None and out is None and dtype is None and p is not None: + if isinstance(p, str): + if p == "MSG": + return _VF.frobenius_norm(input, dim=(), keepdim=keepdim) + if not isinstance(p, str): + _dim = [i for i in range(ndim)] + return _VF.norm(input, p, dim=_dim, keepdim=keepdim) + + + if dim is not None: + if isinstance(dim, int): + _dim = [dim] + else: + _dim = dim + else: + _dim = None if isinstance(p, str): + if p == "MSG": + if dtype is not None: + raise ValueError("MSG") if _dim is None: + _dim = [i for i in range(ndim)] + if out is None: + return _VF.frobenius_norm(input, _dim, keepdim=keepdim) + else: + return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out) + elif p == "MSG": + if dtype is not None: + raise ValueError("MSG") + if _dim is None: + if out is None: + return _VF.nuclear_norm(input, keepdim=keepdim) + else: + return _VF.nuclear_norm(input, keepdim=keepdim, out=out) + else: + if out is None: + return _VF.nuclear_norm(input, _dim, keepdim=keepdim) + else: + return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out) + raise RuntimeError(f"MSG") + else: + if _dim is None: + _dim = [i for i in range(ndim)] if out is None: + if dtype is None: + return _VF.norm(input, p, _dim, keepdim=keepdim) + else: + return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype) + else: + if dtype is None: + return _VF.norm(input, p, _dim, keepdim=keepdim, out=out) + else: + return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out) def chain_matmul(*matrices): + r + if not torch.jit.is_scripting(): + if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices): + return handle_torch_function(chain_matmul, matrices, *matrices) + return _VF.chain_matmul(matrices) +def _lu_impl(A, pivot=True, get_infos=False, out=None): + + r + + return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos)) +if TYPE_CHECKING: + _ListOrSeq = Sequence[Tensor] +else: + _ListOrSeq = List[Tensor] def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None: + get_infos_int = 1 if get_infos else 0 + if out_len - get_infos_int != 2: + raise TypeError(f"MSG") + if not isinstance(out, (tuple, list)): + raise TypeError(f"MSG") def _lu_with_infos(A, pivot=True, get_infos=False, out=None): + + if not torch.jit.is_scripting(): + if type(A) is not Tensor and has_torch_function((A,)): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) + result = _lu_impl(A, pivot, get_infos, out) + if out is not None: + _check_list_size(len(out), get_infos, out) + for i in range(len(out)): + out[i].resize_as_(result[i]).copy_(result[i]) + return out + else: + return result def _lu_no_infos(A, pivot=True, get_infos=False, out=None): + + + if not torch.jit.is_scripting(): + if type(A) is not Tensor and has_torch_function((A,)): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) + result = _lu_impl(A, pivot, get_infos, out) + if out is not None: + _check_list_size(len(out), get_infos, out) + for i in range(len(out)): + out[i].resize_as_(result[i]).copy_(result[i]) + return out + else: + return result[0], result[1] lu = boolean_dispatch( + arg_name='MSG', + arg_index=2, + default=False, + if_true=_lu_with_infos, + if_false=_lu_no_infos, + module_name=__name__, + func_name='MSG') +lu.__doc__ = _lu_impl.__doc__ def align_tensors(*tensors): + raise RuntimeError('MSG') +import errno +import hashlib +import os +import re +import shutil +import sys +import tempfile +import torch +import warnings +import zipfile from urllib.request import urlopen, Request +from urllib.parse import urlparse try: + from tqdm.auto import tqdm +except ImportError: + try: + from tqdm import tqdm + except ImportError: + class tqdm(object): def __init__(self, total=None, disable=False, + unit=None, unit_scale=None, unit_divisor=None): + self.total = total + self.disable = disable + self.n = 0 + def update(self, n): + if self.disable: + return self.n += n + if self.total is None: + sys.stderr.write("MSG".format(self.n)) + else: + sys.stderr.write("MSG".format(100 * self.n / float(self.total))) + sys.stderr.flush() def __enter__(self): + return self def __exit__(self, exc_type, exc_val, exc_tb): + if self.disable: + return sys.stderr.write('MSG') +HASH_REGEX = re.compile(r'MSG') MASTER_BRANCH = 'MSG' +ENV_TORCH_HOME = 'MSG' +ENV_XDG_CACHE_HOME = 'MSG' +DEFAULT_CACHE_DIR = 'MSG' +VAR_DEPENDENCY = 'MSG' +MODULE_HUBCONF = 'MSG' +READ_DATA_CHUNK = 8192 +_hub_dir = None def import_module(name, path): + import importlib.util + from importlib.abc import Loader + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + assert isinstance(spec.loader, Loader) + spec.loader.exec_module(module) + return module +def _remove_if_exists(path): + if os.path.exists(path): + if os.path.isfile(path): + os.remove(path) + else: + shutil.rmtree(path) +def _git_archive_link(repo_owner, repo_name, branch): + return 'MSG'.format(repo_owner, repo_name, branch) +def _load_attr_from_module(module, func_name): + + if func_name not in dir(module): + return None + return getattr(module, func_name) +def _get_torch_home(): + torch_home = os.path.expanduser( + os.getenv(ENV_TORCH_HOME, + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, + DEFAULT_CACHE_DIR), 'MSG'))) + return torch_home +def _parse_repo_info(github): + branch = MASTER_BRANCH + if 'MSG' in github: + repo_info, branch = github.split('MSG') + else: + repo_info = github + repo_owner, repo_name = repo_info.split('MSG') + return repo_owner, repo_name, branch +def _get_cache_or_reload(github, force_reload, verbose=True): + + hub_dir = get_dir() + if not os.path.exists(hub_dir): + os.makedirs(hub_dir) + + repo_owner, repo_name, branch = _parse_repo_info(github) + + + + + normalized_br = branch.replace('MSG', 'MSG') + + + + + repo_dir = os.path.join(hub_dir, 'MSG'.join([repo_owner, repo_name, normalized_br])) use_cache = (not force_reload) and os.path.exists(repo_dir) if use_cache: + if verbose: + sys.stderr.write('MSG'.format(repo_dir)) + else: + cached_file = os.path.join(hub_dir, normalized_br + 'MSG') + _remove_if_exists(cached_file) url = _git_archive_link(repo_owner, repo_name, branch) + sys.stderr.write('MSG'.format(url, cached_file)) + download_url_to_file(url, cached_file, progress=False) with zipfile.ZipFile(cached_file) as cached_zipfile: + extraced_repo_name = cached_zipfile.infolist()[0].filename + extracted_repo = os.path.join(hub_dir, extraced_repo_name) + _remove_if_exists(extracted_repo) + cached_zipfile.extractall(hub_dir) _remove_if_exists(cached_file) + _remove_if_exists(repo_dir) + shutil.move(extracted_repo, repo_dir) return repo_dir +def _check_module_exists(name): + if sys.version_info >= (3, 4): + import importlib.util + return importlib.util.find_spec(name) is not None + elif sys.version_info >= (3, 3): + import importlib.find_loader + return importlib.find_loader(name) is not None + else: + try: + import imp + result = imp.find_module(name) + if result: + return True + except ImportError: + pass + path = sys.path + for item in path: + importer = sys.path_importer_cache.get(item) + if importer: + try: + result = importer.find_module(name, [item]) + if result: + return True + except ImportError: + pass + return False def _check_dependencies(m): + dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) if dependencies is not None: + missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] + if len(missing_deps): + raise RuntimeError('MSG'.format('MSG'.join(missing_deps))) +def _load_entry_from_hubconf(m, model): + if not isinstance(model, str): + raise ValueError('MSG') + + + + _check_dependencies(m) func = _load_attr_from_module(m, model) if func is None or not callable(func): + raise RuntimeError('MSG'.format(model)) return func +def get_dir(): + r + + if os.getenv('MSG'): + warnings.warn('MSG') if _hub_dir is not None: + return _hub_dir + return os.path.join(_get_torch_home(), 'MSG') +def set_dir(d): + r + global _hub_dir + _hub_dir = d +def list(github, force_reload=False): + r + repo_dir = _get_cache_or_reload(github, force_reload, True) sys.path.insert(0, repo_dir) hub_module = import_module(MODULE_HUBCONF, repo_dir + 'MSG' + MODULE_HUBCONF) sys.path.remove(repo_dir) + entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('MSG')] return entrypoints +def help(github, model, force_reload=False): + r + repo_dir = _get_cache_or_reload(github, force_reload, True) sys.path.insert(0, repo_dir) hub_module = import_module(MODULE_HUBCONF, repo_dir + 'MSG' + MODULE_HUBCONF) sys.path.remove(repo_dir) entry = _load_entry_from_hubconf(hub_module, model) return entry.__doc__ +def load(repo_or_dir, model, *args, **kwargs): + r + source = kwargs.pop('MSG', 'MSG').lower() + force_reload = kwargs.pop('MSG', False) + verbose = kwargs.pop('MSG', True) if source not in ('MSG', 'MSG'): + raise ValueError( + f'MSG') if source == 'MSG': + repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose) model = _load_local(repo_or_dir, model, *args, **kwargs) + return model +def _load_local(hubconf_dir, model, *args, **kwargs): + r + sys.path.insert(0, hubconf_dir) hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) + hub_module = import_module(MODULE_HUBCONF, hubconf_path) entry = _load_entry_from_hubconf(hub_module, model) + model = entry(*args, **kwargs) sys.path.remove(hubconf_dir) return model +def download_url_to_file(url, dst, hash_prefix=None, progress=True): + r + file_size = None + + + req = Request(url, headers={"MSG": "MSG"}) + u = urlopen(req) + meta = u.info() + if hasattr(meta, 'MSG'): + content_length = meta.getheaders("MSG") + else: + content_length = meta.get_all("MSG") + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + + + dst = os.path.expanduser(dst) + dst_dir = os.path.dirname(dst) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) try: + if hash_prefix is not None: + sha256 = hashlib.sha256() + with tqdm(total=file_size, disable=not progress, + unit='MSG', unit_scale=True, unit_divisor=1024) as pbar: + while True: + buffer = u.read(8192) + if len(buffer) == 0: + break + f.write(buffer) + if hash_prefix is not None: + sha256.update(buffer) + pbar.update(len(buffer)) f.close() + if hash_prefix is not None: + digest = sha256.hexdigest() + if digest[:len(hash_prefix)] != hash_prefix: + raise RuntimeError('MSG' + .format(hash_prefix, digest)) + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) def _download_url_to_file(url, dst, hash_prefix=None, progress=True): + warnings.warn('MSG') + download_url_to_file(url, dst, hash_prefix, progress) +def _is_legacy_zip_format(filename): + if zipfile.is_zipfile(filename): + infolist = zipfile.ZipFile(filename).infolist() + return len(infolist) == 1 and not infolist[0].is_dir() + return False def _legacy_zip_load(filename, model_dir, map_location): + warnings.warn('MSG' + 'MSG' + 'MSG') + + + + with zipfile.ZipFile(filename) as f: + members = f.infolist() + if len(members) != 1: + raise RuntimeError('MSG') + f.extractall(model_dir) + extraced_name = members[0].filename + extracted_file = os.path.join(model_dir, extraced_name) + return torch.load(extracted_file, map_location=map_location) def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None): + r + + if os.getenv('MSG'): + warnings.warn('MSG') if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'MSG') try: + os.makedirs(model_dir) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('MSG'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) if _is_legacy_zip_format(cached_file): + return _legacy_zip_load(cached_file, model_dir, map_location) + return torch.load(cached_file, map_location=map_location) +import __future__ import collections +import functools +import types +from typing import Dict, Set, List, Any, Callable, Iterable import torch +from torch._C import _is_torch_function_enabled, _disabled_torch_function_impl @functools.lru_cache(None) +def get_ignored_functions() -> Set[Callable]: + + Tensor = torch.Tensor + return { + torch.typename, + torch.is_tensor, + torch.is_storage, + torch.set_default_tensor_type, + torch.set_rng_state, + torch.get_rng_state, + torch.manual_seed, + torch.initial_seed, + torch.seed, + torch.save, + torch.load, + torch.set_printoptions, + torch.fork, + torch.get_default_dtype, + torch.get_num_interop_threads, + torch.get_num_threads, + torch.init_num_threads, + torch.import_ir_module, + torch.import_ir_module_from_buffer, + torch.is_anomaly_enabled, + torch.is_grad_enabled, + torch.merge_type_from_type_comment, + torch.parse_ir, + torch.parse_schema, + torch.parse_type_comment, + torch.set_anomaly_enabled, + torch.set_flush_denormal, + torch.set_num_interop_threads, + torch.set_num_threads, + torch.wait, + torch.as_tensor, + torch.from_numpy, + torch.get_device, + torch.tensor, + torch.default_generator, + torch.has_cuda, + torch.has_cudnn, + torch.has_lapack, + torch.cpp, + torch.device, + torch.dtype, + torch.finfo, + torch.has_mkl, + torch.has_mkldnn, + torch.has_openmp, + torch.iinfo, + torch.memory_format, + torch.qscheme, + torch.set_grad_enabled, + torch.no_grad, + torch.enable_grad, + torch.layout, + torch.align_tensors, + torch.arange, + torch.as_strided, + torch.bartlett_window, + torch.blackman_window, + torch.can_cast, + torch.cudnn_affine_grid_generator, + torch.cudnn_batch_norm, + torch.cudnn_convolution, + torch.cudnn_convolution_transpose, + torch.cudnn_grid_sampler, + torch.cudnn_is_acceptable, + torch.empty, + torch.empty_meta, + torch.empty_strided, + torch.empty_quantized, + torch.eye, + torch.from_file, + torch.full, + torch.hamming_window, + torch.hann_window, + torch.kaiser_window, + torch.linspace, + torch.logspace, + torch.mkldnn_adaptive_avg_pool2d, + torch.mkldnn_convolution, + torch.mkldnn_convolution_backward_weights, + torch.mkldnn_max_pool2d, + torch.mkldnn_max_pool3d, + torch.normal, + torch.ones, + torch.promote_types, + torch.rand, + torch.randn, + torch.randint, + torch.randperm, + torch.range, + torch.result_type, + torch.scalar_tensor, + torch.sparse_coo_tensor, + torch.tril_indices, + torch.triu_indices, + torch.vander, + torch.zeros, + torch.nn.functional.assert_int_or_pair, + torch.nn.functional.boolean_dispatch, + torch.nn.functional.upsample, + torch.nn.functional.upsample_bilinear, + torch.nn.functional.upsample_nearest, + torch.nn.functional.has_torch_function, + torch.nn.functional.handle_torch_function, + torch.nn.functional.sigmoid, + torch.nn.functional.hardsigmoid, + torch.nn.functional.tanh, + torch.set_autocast_enabled, + torch.is_autocast_enabled, + torch.clear_autocast_cache, + torch.autocast_increment_nesting, + torch.autocast_decrement_nesting, + torch.nn.functional.hardswish, + torch.is_vulkan_available, + torch.is_deterministic, + torch.set_deterministic, + torch.unify_type_list, + Tensor.__delitem__, + Tensor.__dir__, + Tensor.__getattribute__, + Tensor.__init__, + Tensor.__init_subclass__, + Tensor.__delattr__, + Tensor.__setattr__, + Tensor.__torch_function__, + Tensor.__new__, + Tensor.__class__, + Tensor.__subclasshook__, + Tensor.as_subclass, + Tensor.reinforce, + Tensor.new, + Tensor.new_tensor, + Tensor.new_empty, + Tensor.new_zeros, + Tensor.new_ones, + Tensor.new_full, + Tensor._make_subclass, + Tensor.stride, + Tensor.unflatten, + } +@functools.lru_cache(None) +def get_testing_overrides() -> Dict[Callable, Callable]: + + + + + + + + + Tensor = torch.Tensor + ret = { + torch.abs: lambda input, out=None: -1, + torch.absolute: lambda input, out=None: -1, + torch.adaptive_avg_pool1d: lambda input, output_size: -1, + torch.adaptive_max_pool1d: lambda inputs, output_size: -1, + torch.acos: lambda input, out=None: -1, + torch.arccos: lambda input, out=None: -1, + torch.acosh: lambda input, out=None: -1, + torch.arccosh: lambda input, out=None: -1, + torch.add: lambda input, other, out=None: -1, + torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1, + torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1, + torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1, + torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1, + torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1, + torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1, + torch.affine_grid_generator: lambda theta, size, align_corners: -1, + torch.all: lambda input, dim=None: -1, + torch.allclose: lambda input, other, trol=1e-05, atol=1e-08, equal_nan=False: -1, + torch.alpha_dropout: lambda input, p, train, inplace=False: -1, + torch.amax: lambda input, dim=None: -1, + torch.amin: lambda input, dim=None: -1, + torch.angle: lambda input, out=None: -1, + torch.any: lambda input, dim=None, keepdim=False, out=None: -1, + torch.argmax: lambda input: -1, + torch.argmin: lambda input: -1, + torch.argsort: lambda input, dim=None: -1, + torch.asin: lambda input, out=None: -1, + torch.arcsin: lambda input, out=None: -1, + torch.asinh: lambda input, out=None: -1, + torch.arcsinh: lambda input, out=None: -1, + torch.atan: lambda input, out=None: -1, + torch.arctan: lambda input, out=None: -1, + torch.atan2: lambda input, other, out=None: -1, + torch.atanh: lambda input, out=None: -1, + torch.arctanh: lambda input, out=None: -1, + torch.atleast_1d: lambda *tensors: -1, + torch.atleast_2d: lambda *tensors: -1, + torch.atleast_3d: lambda *tensors: -1, + torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1, + torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1, + torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1, + torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, mean_dy, mean_dy_xmu: -1, + torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1, + torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1, + torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1, + torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1, + torch.batch_norm_stats: lambda input, eps: -1, + torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1, + torch.bernoulli: lambda input, generator=None, out=None: -1, + torch.bilinear: lambda input1, input2, weight, bias: -1, + torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None, + reduction='MSG', pos_weight=None: -1), + torch.bincount: lambda input, weights=None, minlength=0: -1, + torch.binomial: lambda count, prob, generator=None: -1, + torch.bitwise_and: lambda input, other, out=None: -1, + torch.bitwise_not: lambda input, out=None: -1, + torch.bitwise_or: lambda input, other, out=None: -1, + torch.bitwise_xor: lambda input, other, out=None: -1, + torch.block_diag: lambda *tensors: -1, + torch.bmm: lambda input, mat2, out=None: -1, + torch.broadcast_tensors: lambda *tensors: -1, + torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1, + torch.cartesian_prod: lambda *tensors: -1, + torch.cat: lambda tensors, dim=0, out=None: -1, + torch.cdist: lambda x1, x2, p=2.0, compute_mode='MSG': -1, + torch.ceil: lambda input, out=None: -1, + torch.celu: lambda input, alhpa=1., inplace=False: -1, + torch.chain_matmul: lambda *matrices: -1, + torch.channel_shuffle: lambda input, groups : -1, + torch.cholesky: lambda input, upper=False, out=None: -1, + torch.cholesky_inverse: lambda input, upper=False, out=None: -1, + torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1, + torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1, + torch.chunk: lambda input, chunks, dim=0: -1, + torch.clamp: lambda input, min=None, max=None, out=None: -1, + torch.clip: lambda input, min=None, max=None, out=None: -1, + torch.clamp_min: lambda input, min, out=None: -1, + torch.clamp_max: lambda input, max, out=None: -1, + torch.clone: lambda input: -1, + torch.combinations: lambda input, r=2, with_replacement=False: -1, + torch.complex: lambda real, imag: -1, + torch.polar: lambda abs, ang: -1, + torch.conj: lambda input, out=None: -1, + torch.constant_pad_nd: lambda input, pad, value=0: -1, + torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, + torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, + torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, + torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1, + torch.conv_tbc: lambda input, weight, bias, pad=0: -1, + torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1, + torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1, + torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1, + torch.cos: lambda input, out=None: -1, + torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='MSG': -1, + torch.cosh: lambda input, out=None: -1, + torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1, + torch.count_nonzero: lambda input: -1, + torch.cross: lambda input, other, dim=-1, out=None: -1, + torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='MSG', + zero_infinity=False: -1), + torch.cummax: lambda input, dim, out=None: -1, + torch.cummin: lambda input, dim, out=None: -1, + torch.cumprod: lambda input, dim, out=None, dtype=None: -1, + torch.cumsum: lambda input, dim, out=None, dtype=None: -1, + torch.logcumsumexp: lambda input, dim, out=None: -1, + torch.deg2rad: lambda input, out=None: -1, + torch.dequantize: lambda input: -1, + torch.det: lambda input: -1, + torch.linalg.det: lambda input: -1, + torch.detach: lambda input: -1, + torch.diag: lambda input, diagonal=0, out=None: -1, + torch.diag_embed: lambda input, diagonal=0, out=None: -1, + torch.diagflat: lambda input, offset=0: -1, + torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1, + torch.digamma: lambda input, out=None: -1, + torch.dist: lambda input, other, p=2: -1, + torch.div: lambda input, other, out=None: -1, + torch.divide: lambda input, other, out=None: -1, + torch.dot: lambda mat1, mat2: -1, + torch.dropout: lambda input, p, train, inplace=False: -1, + torch.dsmm: lambda input, mat2: -1, + torch.hsmm: lambda mat1, mat2: -1, + torch.dstack: lambda tensors, out=None: -1, + torch.eig: lambda input, eigenvectors=False, out=None: -1, + torch.einsum: lambda equation, *operands: -1, + torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, + sparse=False: -1), + torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, + mode='MSG', sparse=False, per_sample_weights=None: -1), + torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch.eq: lambda input, other, out=None: -1, + torch.equal: lambda input, other: -1, + torch.erf: lambda input, out=None: -1, + torch.erfc: lambda input, out=None: -1, + torch.erfinv: lambda input, out=None: -1, + torch.exp: lambda input, out=None: -1, + torch.exp2: lambda input, out=None: -1, + torch.expm1: lambda input, out=None: -1, + torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1, + torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1, + torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1, + torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1, + torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, + torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale, + weight_zero_point, bias: -1), + torch.fbgemm_linear_quantize_weight: lambda input: -1, + torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1, + torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1, + torch.feature_alpha_dropout: lambda input, p, train: -1, + torch.feature_dropout: lambda input, p, train: -1, + torch.fix: lambda input, out=None: -1, + torch.fft: lambda input, signal_ndim, normalized=False: -1, + torch.flatten: lambda input, start_dim=0, end_dim=-1: -1, + torch.flip: lambda input, dims: -1, + torch.fliplr: lambda input: -1, + torch.flipud: lambda input: -1, + torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1, + torch.floor: lambda input, out=None: -1, + torch.floor_divide: lambda input, other: -1, + torch.fmod: lambda input, other, out=None: -1, + torch.frac: lambda input, out=None: -1, + torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, + torch.functional.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1, + torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1, + torch.gcd: lambda input, other, out=None: -1, + torch.ge: lambda input, other, out=None: -1, + torch.greater_equal: lambda input, other, out=None: -1, + torch.geqrf: lambda input, out=None: -1, + torch.i0: lambda input, out=None: -1, + torch.outer: lambda input, vec2, out=None: -1, + torch.ger: lambda input, vec2, out=None: -1, + torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1, + torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1, + torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1, + torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1, + torch.gru: lambda input, hx, params, has_biases, num_layers, gropout, train, bidirectional, batch_first: -1, + torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, + torch.gt: lambda input, other, out=None: -1, + torch.greater: lambda input, other, out=None: -1, + torch.hardshrink: lambda input, lambd=0.5: -1, + torch.heaviside: lambda input, values, out=None: -1, + torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='MSG': -1, + torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1, + torch.hspmm: lambda mat1, mat2, out=None: -1, + torch.hstack: lambda tensors, out=None: -1, + torch.hypot: lambda input, other, out=None: -1, + torch.ifft: lambda input, signal_ndim, normalized=False: -1, + torch.imag: lambda input, out=None: -1, + torch.index_add: lambda input, dim, index, source: -1, + torch.index_copy: lambda input, dim, index, source: -1, + torch.index_put: lambda input, indices, values, accumulate=False: -1, + torch.index_select: lambda input, dim, index, out=None: -1, + torch.index_fill: lambda input, dim, index, value: -1, + torch.isfinite: lambda tensor: -1, + torch.isinf: lambda tensor: -1, + torch.isreal: lambda tensor: -1, + torch.isposinf: lambda input, out=None: -1, + torch.isneginf: lambda input, out=None: -1, + torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, + cudnn_enabled: -1), + torch.int_repr: lambda input: -1, + torch.inverse: lambda input, out=None: -1, + torch.irfft: lambda input, signal_ndim, normalized=False, onesided=True, signal_sizes=None: -1, + torch.is_complex: lambda input: -1, + torch.is_distributed: lambda input: -1, + torch.is_floating_point: lambda input: -1, + torch.is_nonzero: lambda input: -1, + torch.is_same_size: lambda input, other: -1, + torch.is_signed: lambda input: -1, + torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1, + torch.isnan: lambda input: -1, + torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, + normalized=False, onesided=None, length=None, return_complex=False: -1), + torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='MSG', log_target=False: -1, + torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1, + torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1, + torch.lcm: lambda input, other, out=None: -1, + torch.le: lambda input, other, out=None: -1, + torch.less_equal: lambda input, other, out=None: -1, + torch.lerp: lambda input, end, weight, out=None: -1, + torch.lgamma: lambda input, out=None: -1, + torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, + tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1, + torch.log: lambda input, out=None: -1, + torch.log_softmax: lambda input, dim, dtype=None: -1, + torch.log10: lambda input, out=None: -1, + torch.log1p: lambda input, out=None: -1, + torch.log2: lambda input, out=None: -1, + torch.logaddexp: lambda input, other, out=None: -1, + torch.logaddexp2: lambda input, other, out=None: -1, + torch.logdet: lambda input: -1, + torch.logical_and: lambda input, other, out=None: -1, + torch.logical_not: lambda input, out=None: -1, + torch.logical_or: lambda input, other, out=None: -1, + torch.logical_xor: lambda input, other, out=None: -1, + torch.logsumexp: lambda input, names, keepdim=False, out=None: -1, + torch.logit: lambda input, eps=None: -1, + torch.logsumexp: lambda input, names, keepdim=False, out=None: -1, + torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1, + torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, + torch.lstsq: lambda input, A, out=None: -1, + torch.lt: lambda input, other, out=None: -1, + torch.less: lambda input, other, out=None: -1, + torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1, + torch.lu_solve: lambda input, LU_data, LU_pivots, out=None: -1, + torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='MSG': -1, + torch.masked_fill: lambda input, mask, value: -1, + torch.masked_scatter: lambda input, mask, source: -1, + torch.masked_select: lambda input, mask, out=None: -1, + torch.matmul: lambda input, other, out=None: -1, + torch.matrix_power: lambda input, n: -1, + torch.matrix_rank: lambda input, tol=None, symmetric=False: -1, + torch.matrix_exp: lambda input: -1, + torch.max: lambda input, out=None: -1, + torch.maximum: lambda input, other, out=None: -1, + torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, + torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, + torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, + torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1, + return_indices=False, ceil_mode=False: -1), + torch.mean: lambda input, dim=None: -1, + torch.median: lambda input, dim=None: -1, + torch.meshgrid: lambda *tensors, **kwargs: -1, + torch.min: lambda input, out=None: -1, + torch.minimum: lambda input, other, out=None: -1, + torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training, + exponential_average_factor, epsilon: -1), + torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, + torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation, + groups, benchmark, deterministic: -1), + torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark, + deterministic: -1), + torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, + dropout, train, bidirectional, batch_sizes, dropout_state: -1), + torch.mm: lambda input, mat2, out=None: -1, + torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1, + torch.movedim: lambda input, source, destination: -1, + torch.mul: lambda input, other, out=None: -1, + torch.multiply: lambda input, other, out=None: -1, + torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1, + torch.mv: lambda input, vec, out=None: -1, + torch.mvlgamma: lambda input, p: -1, + torch.narrow: lambda input, dim, start, length: -1, + torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1, + torch.native_layer_norm: lambda input, weight, bias, M, N, eps: -1, + torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1, + torch.native_norm: lambda input, p=2: -1, + torch.native_norm: lambda input, p=2: -1, + torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1, + torch.ne: lambda input, other, out=None: -1, + torch.not_equal: lambda input, other, out=None: -1, + torch.neg: lambda input, out=None: -1, + torch.negative: lambda input, out=None: -1, + torch.nextafter: lambda input, other, out=None: -1, + torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1, + torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1, + torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.adaptive_max_pool3d: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1, + torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1, + torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1, + torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, + count_include_pad=True, divisor_override=None: -1), + torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, + count_include_pad=True, divisor_override=None: -1), + torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False, + momentum=0.1, eps=1e-05: -1), + torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1, + torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None, + reduction="MSG": -1), + torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, + reduce=None, reduction="MSG", pos_weight=None: -1), + torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1, + torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None, + reduce=None, reduction='MSG': -1), + torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100, + reduce=None, reduction="MSG": -1), + torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, + reduction='MSG', zero_infinity=False: -1), + torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1, + torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1, + torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1, + torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1, + torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, + scale_grad_by_freq=False, sparse=False: -1), + torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2, + scale_grad_by_freq=False, mode='MSG', sparse=False, per_sample_weights=None, + include_last_offset=False: -1), + torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1, + torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1, + torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None, + return_indices=False, _random_samples=None: -1), + torch.nn.functional.fractional_max_pool2d_with_indices: ( + lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, + _random_samples=None: -1), + torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None, + return_indices=False, _random_samples=None: -1), + torch.nn.functional.fractional_max_pool3d_with_indices: ( + lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, + _random_samples=None: -1), + torch.nn.functional.gelu: lambda input: -1, + torch.nn.functional.glu: lambda input, dim=-1: -1, + torch.nn.functional.grid_sample: lambda input, grid, mode='MSG', padding_mode='MSG', align_corners=None: -1, + torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1, + torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1, + torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1, + torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1, + torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None, + reduction='MSG': -1), + torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None, + use_input_stats=True, momentum=0.1, eps=1e-05: -1), + torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='MSG', align_corners=None, + recompute_scale_factor=None: -1), + torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='MSG', log_target=False: -1, + torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='MSG': -1, + torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, + torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1, + torch.nn.functional.linear: lambda input, weight, bias=None: -1, + torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1, + torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1, + torch.nn.functional.logsigmoid: lambda input: -1, + torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1, + torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1, + torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None, + reduce=None, reduction='MSG': -1), + torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1, + ceil_mode=False, return_indices=False: -1), + torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1, + return_indices=False, ceil_mode=False: -1), + torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1, + ceil_mode=False, return_indices=False: -1), + torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1, + return_indices=False, ceil_mode=False: -1), + torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1, + return_indices=False, ceil_mode=False: -1), + torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1, + return_indices=False, ceil_mode=False: -1), + torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, + torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, + torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, + torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='MSG': -1, + torch.nn.functional.multi_head_attention_forward: ( + lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, + add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, + need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, + v_proj_weight=None, static_k=None, static_v=None: -1), + torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None, + reduce=None, reduction='MSG': -1), + torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None, + reduction='MSG': -1), + torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None, + reduce=None, reduction='MSG': -1), + torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100, + reduce=None, reduction='MSG': -1), + torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1, + torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1, + torch.nn.functional.pad: lambda input, pad, mode='MSG', value=0: -1, + torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1, + torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None, + eps=1e-08, reduce=None, reduction='MSG': -1), + torch.nn.functional.prelu: lambda input, weight: -1, + torch.nn.functional.relu: lambda input, inplace=False: -1, + torch.nn.functional.relu6: lambda input, inplace=False: -1, + torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, + torch.nn.functional.selu: lambda input, inplace=False: -1, + torch.nn.functional.silu: lambda input, inplace=False: -1, + torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='MSG', beta=1.: -1, + torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='MSG': -1, + torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1, + torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1, + torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1, + torch.nn.functional.softshrink: lambda input, lambd=0.5: -1, + torch.nn.functional.softsign: lambda input: -1, + torch.nn.functional.tanhshrink: lambda input: -1, + torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1, + torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, + swap=False, size_average=None, reduce=None, reduction='MSG': -1), + torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *, + distance_function=None, margin=1.0, + swap=False, reduction='MSG': -1), + torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1, + torch.nonzero: lambda input, as_tuple=False: -1, + torch.norm: lambda input, p='MSG', dim=None, keepdim=False, out=None, dtype=None: -1, + torch.norm_except_dim: lambda v, pow=2, dim=0: -1, + torch.nuclear_norm: lambda input, p='MSG', dim=None, keepdim=False, out=None, dtype=None: -1, + torch.numel: lambda input: -1, + torch.orgqr: lambda input1, input2: -1, + torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1, + torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1, + torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1, + torch.pdist: lambda input, p=2: -1, + torch.pinverse: lambda input, rcond=1e-15: -1, + torch.pixel_shuffle: lambda input, upscale_factor: -1, + torch.poisson: lambda input, generator=None: -1, + torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1, + torch.polygamma: lambda input, n, out=None: -1, + torch.prelu: lambda input, weight: -1, + torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch.pow: lambda input, exponent, out=None: -1, + torch.prod: lambda input, dtype=None: -1, + torch.q_per_channel_axis: lambda input: -1, + torch.q_per_channel_scales: lambda input: -1, + torch.q_per_channel_zero_points: lambda input: -1, + torch.q_scale: lambda input: -1, + torch.q_zero_point: lambda input: -1, + torch.qr: lambda input, some=True, out=None: -1, + torch.quantile: lambda input, q, dim=None, keepdim=False, out=None: -1, + torch.nanquantile: lambda input, q, dim=None, keepdim=False, out=None: -1, + torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1, + torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1, + torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1, + torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, + col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1), torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, + col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1), + torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,), + dilation=(1,), ceil_mode=False: -1), + torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0), + dilation=(1, 1), ceil_mode=False: -1), + torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, + col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1), + torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, + col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1), + torch.rad2deg: lambda input, out=None: -1, + torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, + torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch.real: lambda input, out=None: -1, + torch.vdot: lambda mat1, mat2: -1, + torch.view_as_real: lambda input: -1, + torch.view_as_complex: lambda input: -1, + torch.reciprocal: lambda input, out=None: -1, + torch.relu: lambda input, inplace=False: -1, + torch.remainder: lambda input, other, out=None: -1, + torch.renorm: lambda input, p, dim, maxnorm, out=None: -1, + torch.repeat_interleave: lambda input, dim=None: -1, + torch.reshape: lambda input, shape: -1, + torch.rfft: lambda input, signal_ndim, normalized=False, onesided=True: -1, + torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, + torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, + torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, + torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, + torch.roll: lambda input, shifts, dims=None: -1, + torch.rot90: lambda input, k=1, dims=(0, 1): -1, + torch.round: lambda input, out=None: -1, + torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1, + torch.rsqrt: lambda input, out=None: -1, + torch.rsub: lambda input, other, alpha=1: -1, + torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1, + torch.scatter: lambda input, dim, index, src: -1, + torch.scatter_add: lambda input, dim, index, src: -1, + torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1, + torch.select: lambda input, dim, index: -1, + torch.selu: lambda input, inplace=False: -1, + torch.sigmoid: lambda input, out=None: -1, + torch.sign: lambda input, out=None: -1, + torch.signbit: lambda input, out=None: -1, + torch.sgn: lambda input, out=None: -1, + torch.sin: lambda input, out=None: -1, + torch.sinh: lambda input, out=None: -1, + torch.slogdet: lambda input: -1, + torch.smm: lambda input, mat2: -1, + torch.spmm: lambda input, mat2: -1, + torch.softmax: lambda input, dim, dtype=None: -1, + torch.solve: lambda input, A, out=None: -1, + torch.sort: lambda input, dim=-1, descending=False, out=None: -1, + torch.split: lambda tensor, split_size_or_sections, dim=0: -1, + torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1, + torch.sqrt: lambda input, out=None: -1, + torch.square: lambda input, out=None: -1, + torch.squeeze: lambda input, dim=None, out=None: -1, + torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1, + torch.stack: lambda tensors, dim=0, out=None: -1, + torch.std: lambda input, dim=None: -1, + torch.std_mean: lambda input, dim=None: -1, + torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, + pad_mode='MSG', normalized=False, onesided=True, return_complex=None: -1), + torch.sub: lambda input, other, out=None: -1, + torch.subtract: lambda input, other, out=None: -1, + torch.sum: lambda input, dim=None: -1, + torch.nansum: lambda input, dim=None: -1, + torch.svd: lambda input, some=True, compute_uv=True, out=None: -1, + torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1, + torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1, + torch.t: lambda input: -1, + torch.take: lambda input, index: -1, + torch.tan: lambda input, out=None: -1, + torch.tanh: lambda input, out=None: -1, + torch.tensordot: lambda a, b, dims=2: -1, + torch.threshold: lambda input, threshold, value, inplace=False: -1, + torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1, + torch.trace: lambda input: -1, + torch.transpose: lambda input, dim0, dim1: -1, + torch.trapz: lambda y, x=None, dim=-1: -1, + torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1, + torch.tril: lambda input, diagonal=0, out=None: -1, + torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, + size_average=None, reduce=None, reduction='MSG': -1), + torch.triu: lambda input, diagonal=0, out=None: -1, + torch.true_divide: lambda input, other: -1, + torch.trunc: lambda input, out=None: -1, + torch.unbind: lambda input, dim=0: -1, + torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1, + torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1, + torch.unsafe_chunk: lambda input, chunks, dim=0: -1, + torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1, + torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1, + torch.unsqueeze: lambda input, dim, out=None: -1, + torch.var: lambda input, dim=None: -1, + torch.var_mean: lambda input, dim=None: -1, + torch.vstack: lambda tensors, out=None: -1, + torch.where: lambda condition, x=None, y=None: -1, + torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + Tensor.__floordiv__: lambda self, other: -1, + Tensor.__rfloordiv__: lambda self, other: -1, + Tensor.__ifloordiv__: lambda self, other: -1, + Tensor.__truediv__: lambda self, other: -1, + Tensor.__rtruediv__: lambda self, other: -1, + Tensor.__itruediv__: lambda self, other: -1, + Tensor.__lshift__: lambda self, other: -1, + Tensor.__ilshift__: lambda self, other: -1, + Tensor.__rshift__: lambda self, other: -1, + Tensor.__irshift__: lambda self, other: -1, + Tensor.__float__: lambda self: -1, + Tensor.__complex__: lambda self: -1, + Tensor.__array__: lambda self, dtype: -1, + Tensor.__bool__: lambda self: -1, + Tensor.__contains__: lambda self, other: -1, + Tensor.__neg__: lambda self: -1, + Tensor.__invert__: lambda self: -1, + Tensor.__mod__: lambda self, other: -1, + Tensor.__array_wrap__: lambda self, array: -1, + Tensor.__getitem__: lambda self, idx: -1, + Tensor.__deepcopy__: lambda self, memo: -1, + Tensor.__iter__: lambda self: -1, + Tensor.__int__: lambda self: -1, + Tensor.__long__: lambda self: -1, + Tensor.__hash__: lambda self: -1, + Tensor.__index__: lambda self: -1, + Tensor.__len__: lambda self: -1, + Tensor.__format__: lambda self, format_spec: -1, + Tensor.__reduce_ex__: lambda self, proto: -1, + Tensor.__reversed__: lambda self: -1, + Tensor.__repr__: lambda self: -1, + Tensor.__setitem__: lambda self, k, v: -1, + Tensor.__setstate__: lambda self, d: -1, + Tensor.T.__get__: lambda self: -1, + Tensor._backward_hooks.__get__: lambda self: -1, + Tensor._base.__get__: lambda self: -1, + Tensor._cdata.__get__: lambda self: -1, + Tensor.grad.__get__: lambda self: -1, + Tensor._grad.__get__: lambda self: -1, + Tensor._grad_fn.__get__: lambda self: -1, + Tensor.grad_fn.__get__: lambda self: -1, + Tensor._version.__get__: lambda self: -1, + Tensor.data.__get__: lambda self: -1, + Tensor.device.__get__: lambda self: -1, + Tensor.dtype.__get__: lambda self: -1, + Tensor.is_cuda.__get__: lambda self: -1, + Tensor.is_leaf.__get__: lambda self: -1, + Tensor.is_meta.__get__: lambda self: -1, + Tensor.is_mkldnn.__get__: lambda self: -1, + Tensor.is_quantized.__get__: lambda self: -1, + Tensor.is_sparse.__get__: lambda self: -1, + Tensor.layout.__get__: lambda self: -1, + Tensor.name.__get__: lambda self: -1, + Tensor.names.__get__: lambda self: -1, + Tensor.ndim.__get__: lambda self: -1, + Tensor.output_nr.__get__: lambda self: -1, + Tensor.requires_grad.__get__: lambda self: -1, + Tensor.shape.__get__: lambda self: -1, + Tensor.volatile.__get__: lambda self: -1, + Tensor.real.__get__: lambda self: -1, + Tensor.imag.__get__: lambda self: -1, + Tensor.__cuda_array_interface__.__get__: lambda self: -1, + Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1, + Tensor._coalesced_: lambda self: -1, + Tensor._dimI: lambda self: -1, + Tensor._dimV: lambda self: -1, + Tensor._indices: lambda self: -1, + Tensor._is_view: lambda self: -1, + Tensor._nnz: lambda self: -1, + Tensor._update_names: lambda self, names, inplace: -1, + Tensor._values: lambda self: -1, + Tensor.align_as: lambda self, other: -1, + Tensor.align_to: lambda self, order, ellipsis_idx: -1, + Tensor.apply_: lambda self, callable: -1, + Tensor.as_strided: lambda self, size, stride: -1, + Tensor.as_strided_: lambda self, size, stride: -1, + Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False: -1, + Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1, + Tensor.bool: lambda self, memory_format=torch.preserve_format: -1, + Tensor.byte: lambda self, memory_format=torch.preserve_format: -1, + Tensor.char: lambda self, memory_format=torch.preserve_format: -1, + Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1, + Tensor.coalesce: lambda self: -1, + Tensor._coalesced_: lambda self, coalesced: -1, + Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1, + Tensor.copy_: lambda self, src, non_blocking=False: -1, + Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1, + Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1, + Tensor.data_ptr: lambda self: -1, + Tensor.dense_dim: lambda self: -1, + Tensor.dim: lambda self: -1, + Tensor.double: lambda self, memory_format=torch.preserve_format: -1, + Tensor.element_size: lambda self: -1, + Tensor.expand: lambda self, size: -1, + Tensor.expand_as: lambda self, other: -1, + Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1, + Tensor.fill_: lambda self, value: -1, + Tensor.fill_diagonal_: lambda self, value: -1, + Tensor.float: lambda self, memory_format=torch.preserve_format: -1, + Tensor.geometric_: lambda self, p, *, generator=None: -1, + Tensor.get_device: lambda self: -1, + Tensor.half: lambda self, memory_format=torch.preserve_format: -1, + Tensor.has_names: lambda self: -1, + Tensor.indices: lambda self: -1, + Tensor.int: lambda self, memory_format=torch.preserve_format: -1, + Tensor.is_coalesced: lambda self: -1, + Tensor.is_contiguous: lambda self: -1, + Tensor.is_pinned: lambda self: -1, + Tensor.is_set_to: lambda self, tensor: -1, + Tensor.is_shared: lambda self: -1, + Tensor.item: lambda self: -1, + Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1, + Tensor.log_softmax: lambda self, dim: -1, + Tensor.long: lambda self, memory_format=torch.preserve_format: -1, + Tensor.map_: lambda self, tensor, callable: -1, + Tensor.map2_: lambda self, x, y, callable: -1, + Tensor.mm: lambda self, mat2: -1, + Tensor.narrow_copy: lambda self, dimension, start, length: -1, + Tensor.ndimension: lambda self: -1, + Tensor.nelement: lambda self: -1, + Tensor.normal_: lambda self: -1, + Tensor.numpy: lambda self: -1, + Tensor.permute: lambda self, dim: -1, + Tensor.pin_memory: lambda self: -1, + Tensor.put_: lambda self, indices, tensor, accumulate=False: -1, + Tensor.qscheme: lambda self: -1, + Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1, + Tensor.record_stream: lambda self, stream: -1, + Tensor.refine_names: lambda self, names: -1, + Tensor.register_hook: lambda self, hook: -1, + Tensor.rename: lambda self, name: -1, + Tensor.repeat: lambda self, *size: -1, + Tensor.requires_grad_: lambda self, requires_grad=True: -1, + Tensor.reshape_as: lambda self, other: -1, + Tensor.resize: lambda self, *size: -1, + Tensor.resize_: lambda self, size: -1, + Tensor.resize_as: lambda self, other: -1, + Tensor.retain_grad: lambda self: -1, + Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1, + Tensor.share_memory_: lambda self: -1, + Tensor.short: lambda self, memory_format=torch.preserve_format: -1, + Tensor.size: lambda self: -1, + Tensor.sparse_dim: lambda self: -1, + Tensor.sparse_mask: lambda self, mask: -1, + Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1, + Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1, + Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1, + Tensor.storage: lambda self: -1, + Tensor.storage_offset: lambda self: -1, + Tensor.storage_type: lambda self: -1, + Tensor.sum_to_size: lambda self, size: -1, + Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1, + Tensor.to_dense: lambda self: -1, + Tensor.to_sparse: lambda self: -1, + Tensor.tolist: lambda self: -1, + Tensor.to_mkldnn: lambda self: -1, + Tensor.type_as: lambda self, other: -1, + Tensor.unfold: lambda self, dimension, size, step: -1, + Tensor.uniform_: lambda self, from_=0, to=1: -1, + Tensor.values: lambda self: -1, + Tensor.view: lambda self, shape: -1, + Tensor.view_as: lambda self, other: -1, + Tensor.zero_: lambda self: -1, + } ret2 = {} + ignored = get_ignored_functions() for k, v in ret.items(): + names = [ + k.__name__, + k.__name__ + "MSG", + "MSG" + k.__name__ + "MSG", + "MSG" + k.__name__ + "MSG", + "MSG" + k.__name__ + "MSG", + ] if k.__name__.startswith("MSG"): + subname = k.__name__[len("MSG"):] + names.extend([ + "MSG" + subname + "MSG", + "MSG" + subname + "MSG", + "MSG" + subname + "MSG" + ]) for name in names: + func = getattr(Tensor, name, None) + if callable(func) and func not in ret and func not in ignored: + ret2[func] = v ret.update(ret2) + return ret +def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]: + + + overloaded_types = [] + overloaded_args = [] + for arg in relevant_args: + arg_type = type(arg) + if (arg_type not in overloaded_types and hasattr(arg_type, 'MSG')): + if overloaded_types: + overloaded_types.append(arg_type) + index = len(overloaded_args) + for i, old_arg in enumerate(overloaded_args): + if issubclass(arg_type, type(old_arg)): + index = i + break + overloaded_args.insert(index, arg) + else: + overloaded_types = [arg_type] + overloaded_args = [arg] return overloaded_args +def handle_torch_function( + public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any: + + + overloaded_args = _get_overloaded_args(relevant_args) + + types = tuple(map(type, overloaded_args)) + for overloaded_arg in overloaded_args: + result = overloaded_arg.__torch_function__(public_api, types, args, kwargs) if result is not NotImplemented: + return result func_name = 'MSG'.format(public_api.__module__, public_api.__name__) + raise TypeError("MSG" + 'MSG' + .format(func_name, list(map(type, overloaded_args)))) def has_torch_function(relevant_args: Iterable[Any]) -> bool: + + return _is_torch_function_enabled() and any( + type(a) is not torch.Tensor and + getattr(a, 'MSG', _disabled_torch_function_impl) + is not _disabled_torch_function_impl + for a in relevant_args + ) @functools.lru_cache(None) +def get_overridable_functions() -> Dict[Any, List[Callable]]: + + overridable_funcs = collections.defaultdict(list) + tested_namespaces = [ + (torch, torch.__all__ + dir(torch._C._VariableFunctions)), + (torch.functional, torch.functional.__all__), + (torch.nn.functional, dir(torch.nn.functional)), + (torch.Tensor, dir(torch.Tensor)) + ] + for namespace, ns_funcs in tested_namespaces: + for func_name in ns_funcs: + if namespace is not torch.Tensor: + if func_name.startswith('MSG'): + continue + elif func_name.endswith('MSG'): + continue + elif not func_name[0].islower(): + continue + elif func_name == 'MSG': + continue + else: + func = getattr(namespace, func_name) + if getattr(object, func_name, None) == func: + continue + if func_name == 'MSG': + continue + func = getattr(namespace, func_name) + if namespace is torch.Tensor and getattr(object, func_name, None) == func: + continue + if isinstance(func, types.ModuleType): + continue + if isinstance(func, __future__._Feature): + continue if not callable(func) and hasattr(func, "MSG"): + overridable_funcs[func].append(func.__get__) + continue if not callable(func): + continue if func in get_ignored_functions(): + msg = ("MSG" + "MSG") + assert func not in get_testing_overrides(), msg.format(namespace, func.__name__) + continue + overridable_funcs[namespace].append(func) + return overridable_funcs @functools.lru_cache(None) +def get_tensor_methods() -> Set[Callable]: + + overridable_funcs = get_overridable_functions() + methods = set(overridable_funcs[torch.Tensor]) + return methods def is_tensor_method_or_property(func: Callable) -> bool: + + return func in get_tensor_methods() or func.__name__ == "MSG" def is_tensor_like(inp): + + return type(inp) is torch.Tensor or hasattr(inp, "MSG") +import torch +from typing import Optional +class SobolEngine(object): + r + MAXBIT = 30 + MAXDIM = 1111 def __init__(self, dimension, scramble=False, seed=None): + if dimension > self.MAXDIM or dimension < 1: + raise ValueError("MSG" + f"MSG") self.seed = seed + self.scramble = scramble + self.dimension = dimension cpu = torch.device("MSG") self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long) + torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension) if self.scramble: + g: Optional[torch.Generator] = None + if self.seed is not None: + g = torch.Generator() + g.manual_seed(self.seed) shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g) + self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))) ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT) + ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril() torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension) + else: + self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long) self.quasi = self.shift.clone(memory_format=torch.contiguous_format) + self.num_generated = 0 def draw(self, n=1, out=None, dtype=torch.float32): + r + result, self.quasi = torch._sobol_engine_draw(self.quasi, n, self.sobolstate, + self.dimension, self.num_generated, dtype=dtype) + self.num_generated += n + if out is not None: + out.resize_as_(result).copy_(result) + return out + return result def reset(self): + r + self.quasi.copy_(self.shift) + self.num_generated = 0 + return self def fast_forward(self, n): + r + torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated) + self.num_generated += n + return self def __repr__(self): + fmt_string = [f'MSG'] + if self.scramble: + fmt_string += ['MSG'] + if self.seed is not None: + fmt_string += [f'MSG'] + return self.__class__.__name__ + 'MSG' + 'MSG'.join(fmt_string) + 'MSG' +import contextlib +import warnings from torch._C import default_generator +import torch +def set_rng_state(new_state) -> None: + r + default_generator.set_state(new_state) +def get_rng_state() -> torch.Tensor: + r + return default_generator.get_state() +def manual_seed(seed) -> torch._C.Generator: + r + seed = int(seed) + import torch.cuda if not torch.cuda._is_in_bad_fork(): + torch.cuda.manual_seed_all(seed) return default_generator.manual_seed(seed) +def seed() -> int: + r + seed = default_generator.seed() + import torch.cuda if not torch.cuda._is_in_bad_fork(): + torch.cuda.manual_seed_all(seed) return seed +def initial_seed() -> int: + r + return default_generator.initial_seed() +_fork_rng_warned_already = False +@contextlib.contextmanager +def fork_rng(devices=None, enabled=True, _caller="MSG", _devices_kw="MSG"): + import torch.cuda + global _fork_rng_warned_already + + if not enabled: + yield + return if devices is None: + num_devices = torch.cuda.device_count() + if num_devices > 1 and not _fork_rng_warned_already: + warnings.warn( + ("MSG" + "MSG" + "MSG" + "MSG" + "MSG" + "MSG" + "MSG" + "MSG" + "MSG" + "MSG" + "MSG" + ).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw)) + _fork_rng_warned_already = True + devices = list(range(num_devices)) + else: + devices = list(devices) cpu_rng_state = torch.get_rng_state() + gpu_rng_states = [] + for device in devices: + gpu_rng_states.append(torch.cuda.get_rng_state(device)) try: + yield + finally: + torch.set_rng_state(cpu_rng_state) + for device, gpu_rng_state in zip(devices, gpu_rng_states): + torch.cuda.set_rng_state(gpu_rng_state, device) +import difflib +import os +import io +import shutil +import struct +import sys +import torch +import tarfile +import tempfile +import warnings +from contextlib import closing, contextmanager +from ._utils import _import_dotted_name +from ._six import string_classes as _string_classes +from torch._utils_internal import get_source_lines_and_file +from torch.types import Storage +from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union +import copyreg +import pickle +import pathlib DEFAULT_PROTOCOL = 2 LONG_SIZE = struct.Struct('MSG').size +INT_SIZE = struct.Struct('MSG').size +SHORT_SIZE = struct.Struct('MSG').size MAGIC_NUMBER = 0x1950a86a20f9469cfc6c +PROTOCOL_VERSION = 1001 +STORAGE_KEY_SEPARATOR = 'MSG' class SourceChangeWarning(Warning): + pass +@contextmanager +def mkdtemp(): + path = tempfile.mkdtemp() + yield path + shutil.rmtree(path) +_package_registry = [] +def _is_zipfile(f) -> bool: + + + + + + + read_bytes = [] + start = f.tell() byte = f.read(1) + while byte != "MSG": + read_bytes.append(byte) + if len(read_bytes) == 4: + break + byte = f.read(1) + f.seek(start) local_header_magic_number = [b'MSG', b'MSG', b'MSG', b'MSG'] + return read_bytes == local_header_magic_number +def register_package(priority, tagger, deserializer): + queue_elem = (priority, tagger, deserializer) + _package_registry.append(queue_elem) + _package_registry.sort() +def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True): + + try: + version_strs = module.__version__.split('MSG') + module_version = tuple( + type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple) + ) + requirement_is_met = module_version >= req_version_tuple except Exception as e: + message = ( + "MSG" + "MSG" + ) % ( + module.__name__, module.__version__, str(req_version_tuple) + ) + if error_if_malformed: + raise RuntimeError(message) from e + else: + warnings.warn(message + 'MSG') + requirement_is_met = True return requirement_is_met +def _cpu_tag(obj): + if type(obj).__module__ == 'MSG': + return 'MSG' +def _cuda_tag(obj): + if type(obj).__module__ == 'MSG': + return 'MSG' + str(obj.get_device()) +def _cpu_deserialize(obj, location): + if location == 'MSG': + return obj +def validate_cuda_device(location): + device = torch.cuda._utils._get_device_index(location, True) if not torch.cuda.is_available(): + raise RuntimeError('MSG' + 'MSG' + 'MSG' + 'MSG'cpu\'MSG' + 'MSG') + device_count = torch.cuda.device_count() + if device >= device_count: + raise RuntimeError('MSG' + f'MSG' + 'MSG' + 'MSG') + return device +def _cuda_deserialize(obj, location): + if location.startswith('MSG'): + device = validate_cuda_device(location) + if getattr(obj, "MSG", False): + storage_type = getattr(torch.cuda, type(obj).__name__) + with torch.cuda.device(device): + return storage_type(obj.size()) + else: + return obj.cuda(device) +register_package(10, _cpu_tag, _cpu_deserialize) +register_package(20, _cuda_tag, _cuda_deserialize) +def location_tag(storage: Storage): + for _, tagger, _ in _package_registry: + location = tagger(storage) + if location: + return location + raise RuntimeError("MSG" + + torch.typename(storage)) +def default_restore_location(storage, location): + for _, _, fn in _package_registry: + result = fn(storage, location) + if result is not None: + return result + raise RuntimeError("MSG" + + torch.typename(storage) + "MSG" + + location + "MSG") +def normalize_storage_type(storage_type): + return getattr(torch, storage_type.__name__) +def storage_to_tensor_type(storage): + storage_type = type(storage) + module = _import_dotted_name(storage_type.__module__) + return getattr(module, storage_type.__name__.replace('MSG', 'MSG')) +def _is_path(name_or_buffer): + return isinstance(name_or_buffer, str) or \ + (sys.version_info[0] == 3 and isinstance(name_or_buffer, pathlib.Path)) +class _opener(object): + def __init__(self, file_like): + self.file_like = file_like def __enter__(self): + return self.file_like def __exit__(self, *args): + pass +class _open_file(_opener): + def __init__(self, name, mode): + super(_open_file, self).__init__(open(name, mode)) def __exit__(self, *args): + self.file_like.close() +class _open_buffer_reader(_opener): + def __init__(self, buffer): + super(_open_buffer_reader, self).__init__(buffer) + _check_seekable(buffer) +class _open_buffer_writer(_opener): + def __exit__(self, *args): + self.file_like.flush() +def _open_file_like(name_or_buffer, mode): + if _is_path(name_or_buffer): + return _open_file(name_or_buffer, mode) + else: + if 'MSG' in mode: + return _open_buffer_writer(name_or_buffer) + elif 'MSG' in mode: + return _open_buffer_reader(name_or_buffer) + else: + raise RuntimeError(f"MSG") +class _open_zipfile_reader(_opener): + def __init__(self, name_or_buffer) -> None: + super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer)) +class _open_zipfile_writer_file(_opener): + def __init__(self, name) -> None: + super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name))) def __exit__(self, *args) -> None: + self.file_like.write_end_of_file() +class _open_zipfile_writer_buffer(_opener): + def __init__(self, buffer) -> None: + self.buffer = buffer + super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer)) def __exit__(self, *args) -> None: + self.file_like.write_end_of_file() + self.buffer.flush() +def _open_zipfile_writer(name_or_buffer): + container: Type[_opener] + if _is_path(name_or_buffer): + container = _open_zipfile_writer_file + else: + container = _open_zipfile_writer_buffer + return container(name_or_buffer) +def _is_compressed_file(f) -> bool: + compress_modules = ['MSG'] + try: + return f.__module__ in compress_modules + except AttributeError: + return False +def _should_read_directly(f): + + if _is_compressed_file(f): + return False + try: + return f.fileno() >= 0 + except io.UnsupportedOperation: + return False + except AttributeError: + return False +def _check_seekable(f) -> bool: def raise_err_msg(patterns, e): + for p in patterns: + if p in str(e): + msg = (str(e) + "MSG" + + "MSG" + + "MSG") + raise type(e)(msg) + raise e try: + f.seek(f.tell()) + return True + except (io.UnsupportedOperation, AttributeError) as e: + raise_err_msg(["MSG", "MSG"], e) + return False def _check_dill_version(pickle_module) -> None: + + if pickle_module.__name__ == 'MSG': + required_dill_version = (0, 3, 1) + if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False): + raise ValueError(( + "MSG" + "MSG" + ) % ( + 'MSG'.join([str(num) for num in required_dill_version]), + pickle_module.__version__ + )) def save(obj, f: Union[str, os.PathLike, BinaryIO], + pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None: + + _check_dill_version(pickle_module) with _open_file_like(f, 'MSG') as opened_file: + if _use_new_zipfile_serialization: + with _open_zipfile_writer(opened_file) as opened_zipfile: + _save(obj, opened_zipfile, pickle_module, pickle_protocol) + return + _legacy_save(obj, opened_file, pickle_module, pickle_protocol) +def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: + import torch.nn as nn + serialized_container_types = {} + serialized_storages = {} def persistent_id(obj: Any) -> Optional[Tuple]: + if isinstance(obj, type) and issubclass(obj, nn.Module): + if obj in serialized_container_types: + return None + serialized_container_types[obj] = True + source_file = source = None + try: + source_lines, _, source_file = get_source_lines_and_file(obj) + source = 'MSG'.join(source_lines) + except Exception: + warnings.warn("MSG" + "MSG" + obj.__name__ + "MSG" + "MSG") + return ('MSG', obj, source_file, source) elif torch.is_storage(obj): + view_metadata: Optional[Tuple[str, int, int]] + obj = cast(Storage, obj) + storage_type = normalize_storage_type(type(obj)) + offset = 0 + obj_key = str(obj._cdata) + location = location_tag(obj) + serialized_storages[obj_key] = obj + is_view = obj._cdata != obj._cdata + if is_view: + view_metadata = (str(obj._cdata), offset, obj.size()) + else: + view_metadata = None return ('MSG', + storage_type, + obj_key, + location, + obj.size(), + view_metadata) + return None sys_info = dict( + protocol_version=PROTOCOL_VERSION, + little_endian=sys.byteorder == 'MSG', + type_sizes=dict( + short=SHORT_SIZE, + int=INT_SIZE, + long=LONG_SIZE, + ), + ) pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) + pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) + pickle_module.dump(sys_info, f, protocol=pickle_protocol) + pickler = pickle_module.Pickler(f, protocol=pickle_protocol) + pickler.persistent_id = persistent_id + pickler.dump(obj) serialized_storage_keys = sorted(serialized_storages.keys()) + pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol) + f.flush() + for key in serialized_storage_keys: + serialized_storages[key]._write_file(f, _should_read_directly(f), True) +def _save(obj, zip_file, pickle_module, pickle_protocol): + serialized_storages = {} def persistent_id(obj): + if torch.is_storage(obj): + storage_type = normalize_storage_type(type(obj)) + obj_key = str(obj._cdata) + location = location_tag(obj) + serialized_storages[obj_key] = obj return ('MSG', + storage_type, + obj_key, + location, + obj.size()) + return None + data_buf = io.BytesIO() + pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol) + pickler.persistent_id = persistent_id + pickler.dump(obj) + data_value = data_buf.getvalue() + zip_file.write_record('MSG', data_value, len(data_value)) + for key in sorted(serialized_storages.keys()): + name = f'MSG' + storage = serialized_storages[key] + if storage.device.type == 'MSG': + num_bytes = storage.size() * storage.element_size() + zip_file.write_record(name, storage.data_ptr(), num_bytes) + else: + buf = io.BytesIO() + storage._write_file(buf, _should_read_directly(buf), False) + buf_value = buf.getvalue() + zip_file.write_record(name, buf_value, len(buf_value)) +def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): + + _check_dill_version(pickle_module) if 'MSG' not in pickle_load_args.keys(): + pickle_load_args['MSG'] = 'MSG' with _open_file_like(f, 'MSG') as opened_file: + if _is_zipfile(opened_file): + orig_position = opened_file.tell() + with _open_zipfile_reader(opened_file) as opened_zipfile: + if _is_torchscript_zip(opened_zipfile): + warnings.warn("MSG" + "MSG" + "MSG", UserWarning) + opened_file.seek(orig_position) + return torch.jit.load(opened_file) + return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args) + return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args) +def _get_layout(name): + + cache = _get_layout.cache + if not cache: + for v in torch.__dict__.values(): + if isinstance(v, torch.layout): + cache[str(v)] = v + return cache[name] +_get_layout.cache = {} +copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) +def _legacy_load(f, map_location, pickle_module, **pickle_load_args): + deserialized_objects: Dict[int, Any] = {} restore_location = _get_restore_location(map_location) def _check_container_source(container_type, source_file, original_source): + try: + current_source = 'MSG'.join(get_source_lines_and_file(container_type)[0]) + except Exception: + warnings.warn("MSG" + "MSG" + container_type.__name__ + "MSG" + "MSG") + return + if original_source != current_source: + if container_type.dump_patches: + file_name = container_type.__name__ + 'MSG' + diff = difflib.unified_diff(current_source.split('MSG'), + original_source.split('MSG'), + source_file, + source_file, lineterm="MSG") + lines = 'MSG'.join(diff) + try: + with open(file_name, 'MSG') as f: + file_size = f.seek(0, 2) + f.seek(0) + if file_size == 0: + f.write(lines) + elif file_size != len(lines) or f.read() != lines: + raise IOError + msg = ("MSG" + file_name + "MSG" + "MSG" + file_name + "MSG" + "MSG") + except IOError: + msg = ("MSG" + "MSG" + file_name + "MSG" + "MSG" + "MSG") + else: + msg = ("MSG" + "MSG" + "MSG" + "MSG") + msg = f"MSG" + warnings.warn(msg, SourceChangeWarning) def legacy_load(f): + deserialized_objects: Dict[int, Any] = {} def persistent_load(saved_id): + if isinstance(saved_id, tuple): + if all(saved_id[1:]): + _check_container_source(*saved_id) + return saved_id[0] + return deserialized_objects[int(saved_id)] with closing(tarfile.open(fileobj=f, mode='MSG', format=tarfile.PAX_FORMAT)) as tar, \ + mkdtemp() as tmpdir: tar.extract('MSG', path=tmpdir) + with open(os.path.join(tmpdir, 'MSG'), 'MSG', 0) as f: + num_storages = pickle_module.load(f, **pickle_load_args) + for i in range(num_storages): + args = pickle_module.load(f, **pickle_load_args) + key, location, storage_type = args + obj = storage_type._new_with_file(f) + obj = restore_location(obj, location) + deserialized_objects[key] = obj storage_views = pickle_module.load(f, **pickle_load_args) + for target_cdata, root_cdata, offset, size in storage_views: + root = deserialized_objects[root_cdata] + deserialized_objects[target_cdata] = root[offset:offset + size] tar.extract('MSG', path=tmpdir) + with open(os.path.join(tmpdir, 'MSG'), 'MSG', 0) as f: + num_tensors = pickle_module.load(f, **pickle_load_args) + for _ in range(num_tensors): + args = pickle_module.load(f, **pickle_load_args) + key, storage_id, original_tensor_type = args + storage = deserialized_objects[storage_id] + tensor_type = storage_to_tensor_type(storage) + ndim, = struct.unpack('MSG', f.read(4)) + f.read(4) + size = struct.unpack(f'MSG', f.read(8 * ndim)) + stride = struct.unpack(f'MSG', f.read(8 * ndim)) + storage_offset, = struct.unpack('MSG', f.read(8)) + tensor = tensor_type().set_(storage, storage_offset, size, stride) + deserialized_objects[key] = tensor pickle_file = tar.extractfile('MSG') + unpickler = pickle_module.Unpickler(pickle_file, **pickle_load_args) + unpickler.persistent_load = persistent_load + result = unpickler.load() + return result deserialized_objects = {} def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] if typename == 'MSG': + if all(data[1:]): + _check_container_source(*data) + return data[0] + elif typename == 'MSG': + data_type, root_key, location, size, view_metadata = data + location = _maybe_decode_ascii(location) + if root_key not in deserialized_objects: + obj = data_type(size) + obj._torch_load_uninitialized = True + deserialized_objects[root_key] = restore_location(obj, location) + storage = deserialized_objects[root_key] + if view_metadata is not None: + view_key, offset, view_size = view_metadata + if view_key not in deserialized_objects: + deserialized_objects[view_key] = storage[offset:offset + view_size] + return deserialized_objects[view_key] + else: + return storage + else: + raise RuntimeError("MSG" % saved_id[0]) _check_seekable(f) + f_should_read_directly = _should_read_directly(f) if f_should_read_directly and f.tell() == 0: + try: + return legacy_load(f) + except tarfile.TarError: + if _is_zipfile(f): + raise RuntimeError( + f"MSG") from None + f.seek(0) if not hasattr(f, 'MSG') and (3, 8, 0) <= sys.version_info < (3, 8, 2): + raise RuntimeError( + "MSG" + f"MSG"{type(f)}\"MSG" + "MSG") magic_number = pickle_module.load(f, **pickle_load_args) + if magic_number != MAGIC_NUMBER: + raise RuntimeError("MSG") + protocol_version = pickle_module.load(f, **pickle_load_args) + if protocol_version != PROTOCOL_VERSION: + raise RuntimeError("MSG" % protocol_version) _sys_info = pickle_module.load(f, **pickle_load_args) + unpickler = pickle_module.Unpickler(f, **pickle_load_args) + unpickler.persistent_load = persistent_load + result = unpickler.load() deserialized_storage_keys = pickle_module.load(f, **pickle_load_args) offset = f.tell() if f_should_read_directly else None + for key in deserialized_storage_keys: + assert key in deserialized_objects + deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly) + if offset is not None: + offset = f.tell() torch._utils._validate_loaded_sparse_tensors() return result +def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: + + + + + + + if isinstance(bytes_str, bytes): + return bytes_str.decode('MSG') + return bytes_str +def _get_restore_location(map_location): + if map_location is None: + restore_location = default_restore_location + elif isinstance(map_location, dict): + def restore_location(storage, location): + location = map_location.get(location, location) + return default_restore_location(storage, location) + elif isinstance(map_location, _string_classes): + def restore_location(storage, location): + return default_restore_location(storage, map_location) + elif isinstance(map_location, torch.device): + def restore_location(storage, location): + return default_restore_location(storage, str(map_location)) + else: + def restore_location(storage, location): + result = map_location(storage, location) + if result is None: + result = default_restore_location(storage, location) + return result + return restore_location +def _load(zip_file, map_location, pickle_module, pickle_file='MSG', **pickle_load_args): + restore_location = _get_restore_location(map_location) loaded_storages = {} def load_tensor(data_type, size, key, location): + name = f'MSG' + dtype = data_type(0).dtype storage = zip_file.get_storage_from_record(name, size, dtype).storage() + loaded_storages[key] = restore_location(storage, location) def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] assert typename == 'MSG', \ + f"MSG" + data_type, key, location, size = data + if key not in loaded_storages: + load_tensor(data_type, size, key, _maybe_decode_ascii(location)) + storage = loaded_storages[key] + return storage + data_file = io.BytesIO(zip_file.get_record(pickle_file)) + unpickler = pickle_module.Unpickler(data_file, **pickle_load_args) + unpickler.persistent_load = persistent_load + result = unpickler.load() torch._utils._validate_loaded_sparse_tensors() return result +def _is_torchscript_zip(zip_file): + return 'MSG' in zip_file.get_all_records() +import io import torch +from ._utils import _type, _cuda +class _StorageBase(object): + is_cuda = False + is_sparse = False def __str__(self): + content = 'MSG' + 'MSG'.join(str(self[i]) for i in range(len(self))) + return content + f'MSG' def __repr__(self): + return str(self) def __iter__(self): + return iter(map(lambda i: self[i], range(self.size()))) def __copy__(self): + return self.clone() def __deepcopy__(self, memo): + memo = memo.setdefault('MSG', {}) + if self._cdata in memo: + return memo[self._cdata] + new_storage = self.clone() + memo[self._cdata] = new_storage + return new_storage def __reduce__(self): + b = io.BytesIO() + torch.save(self, b, _use_new_zipfile_serialization=False) + return (_load_from_bytes, (b.getvalue(),)) def __sizeof__(self): + return super(_StorageBase, self).__sizeof__() + self.element_size() * self.size() def clone(self): + device = self.get_device() if self.is_cuda else -1 + with torch.cuda.device(device): + return type(self)(self.size()).copy_(self) def tolist(self): + return list(self) def cpu(self): + return self.type(getattr(torch, self.__class__.__name__)) def double(self): + return self.type(type(self).__module__ + 'MSG') def float(self): + return self.type(type(self).__module__ + 'MSG') def half(self): + return self.type(type(self).__module__ + 'MSG') def long(self): + return self.type(type(self).__module__ + 'MSG') def int(self): + return self.type(type(self).__module__ + 'MSG') def short(self): + return self.type(type(self).__module__ + 'MSG') def char(self): + return self.type(type(self).__module__ + 'MSG') def byte(self): + return self.type(type(self).__module__ + 'MSG') def bool(self): + return self.type(type(self).__module__ + 'MSG') def bfloat16(self): + return self.type(type(self).__module__ + 'MSG') def complex_double(self): + return self.type(type(self).__module__ + 'MSG') def complex_float(self): + return self.type(type(self).__module__ + 'MSG') def pin_memory(self): + if self.is_cuda: + raise TypeError(f"MSG") + import torch.cuda + allocator = torch.cuda._host_allocator() + return type(self)(self.size(), allocator=allocator).copy_(self) def share_memory_(self): + from torch.multiprocessing import get_sharing_strategy + if self.is_cuda: + pass + elif get_sharing_strategy() == 'MSG': + self._share_filename_() + else: + self._share_fd_() + return self @classmethod + def _new_shared(cls, size): + from torch.multiprocessing import get_sharing_strategy + if cls.is_cuda: + return cls(size) + elif get_sharing_strategy() == 'MSG': + return cls._new_using_filename(size) + else: + return cls._new_using_fd(size) +def _load_from_bytes(b): + return torch.load(io.BytesIO(b)) +_StorageBase.type = _type +_StorageBase.cuda = _cuda +import torch +import torch._C as _C +from torch._namedtensor_internals import update_names, check_serializing_named_tensor, resolve_ellipsis +from torch._namedtensor_internals import unzip_namedshape, single_ellipsis_index, is_ellipsis +from collections import OrderedDict +import torch.utils.hooks as hooks +import warnings +import weakref +from torch._C import _add_docstr +from typing import Any, Dict, Tuple, Union +from numbers import Number +import functools +from typing import Optional +def _wrap_type_error_to_not_implemented(f): + + method_assignments = ('MSG', 'MSG') + assigned = functools.WRAPPER_ASSIGNMENTS @functools.wraps(f, assigned=assigned) + def wrapped(*args, **kwargs): + from torch.overrides import has_torch_function, handle_torch_function + if not all(type(t) is Tensor for t in args) and has_torch_function(args): + return handle_torch_function(wrapped, args, *args, **kwargs) + try: + return f(*args, **kwargs) + except TypeError: + return NotImplemented + return wrapped class Tensor(torch._C._TensorBase): + def __deepcopy__(self, memo): + from torch.overrides import has_torch_function, handle_torch_function + relevant_args = (self,) + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__deepcopy__, relevant_args, self, memo) + if not self.is_leaf: + raise RuntimeError("MSG" + "MSG") + if id(self) in memo: + return memo[id(self)] + with torch.no_grad(): + if self.is_sparse or self.device.type == 'MSG': + new_tensor = self.clone() + else: + new_storage = self.storage().__deepcopy__(memo) + if self.is_quantized: + quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[torch.qscheme, Tensor, Tensor, int]] + if self.qscheme() == torch.per_tensor_affine: + quantizer_params = self.qscheme(), self.q_scale(), self.q_zero_point() + elif self.qscheme() in (torch.per_channel_affine, torch.per_channel_affine_float_qparams): + quantizer_params = self.qscheme(), \ + self.q_per_channel_scales(), \ + self.q_per_channel_zero_points(), \ + self.q_per_channel_axis() + else: + raise RuntimeError(f"MSG") + new_tensor = torch._utils._rebuild_qtensor( + new_storage, + self.storage_offset(), + self.size(), + self.stride(), + quantizer_params, + self.requires_grad, + self._backward_hooks) + else: + new_tensor = self.new() + new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride()) + new_tensor.requires_grad = self.requires_grad + memo[id(self)] = new_tensor + return new_tensor def __reduce_ex__(self, proto): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__reduce_ex__, relevant_args, self, proto) + check_serializing_named_tensor(self) + torch.utils.hooks.warn_if_has_hooks(self) + backward_hooks: Dict[Any, Any] = OrderedDict() + if self.device.type == 'MSG': + arg_xla = (self.cpu().numpy(), + self.dtype, + str(self.device), + self.requires_grad) + return (torch._utils._rebuild_xla_tensor, arg_xla) + if self.is_quantized: + quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]] + if self.qscheme() == torch.per_tensor_affine: + quantizer_params = (torch.per_tensor_affine, + self.q_scale(), + self.q_zero_point()) + elif self.qscheme() in (torch.per_channel_affine, torch.per_channel_affine_float_qparams): + quantizer_params = (torch.per_channel_affine, + self.q_per_channel_scales(), + self.q_per_channel_zero_points(), + self.q_per_channel_axis()) + else: + raise RuntimeError(f"MSG") + args_qtensor = (self.storage(), + self.storage_offset(), + tuple(self.size()), + self.stride(), + quantizer_params, + self.requires_grad, + backward_hooks) + return (torch._utils._rebuild_qtensor, args_qtensor) + elif self.is_sparse: + if self.layout == torch.sparse_coo: + args_sparse = (self.layout, + (self._indices(), + self._values(), + self.size())) + else: + raise NotImplementedError( + 'MSG' % (self.layout)) + return (torch._utils._rebuild_sparse_tensor, args_sparse) + else: + args = (self.storage(), + self.storage_offset(), + tuple(self.size()), + self.stride(), + self.requires_grad, + backward_hooks) + return (torch._utils._rebuild_tensor_v2, args) def __setstate__(self, state): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__setstate__, relevant_args, self, state) + if not self.is_leaf: + raise RuntimeError('MSG') + if len(state) == 4: + self.set_(*state) + return + elif len(state) == 5: + self.data = state[0] + state = (state[3], state[4], state[2]) + self.requires_grad, _, self._backward_hooks = state def __repr__(self): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__repr__, relevant_args, self) + return torch._tensor_str._str(self) def backward(self, gradient=None, retain_graph=None, create_graph=False): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function( + Tensor.backward, + relevant_args, + self, + gradient=gradient, + retain_graph=retain_graph, + create_graph=create_graph) + torch.autograd.backward(self, gradient, retain_graph, create_graph) def register_hook(self, hook): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.register_hook, relevant_args, self, hook) + if not self.requires_grad: + raise RuntimeError("MSG" + "MSG") + if self._backward_hooks is None: + self._backward_hooks = OrderedDict() + if self.grad_fn is not None: + self.grad_fn._register_hook_dict(self) + handle = hooks.RemovableHandle(self._backward_hooks) + self._backward_hooks[handle.id] = hook + return handle def reinforce(self, reward): + def trim(str): + return 'MSG'.join([line.strip() for line in str.split('MSG')]) raise RuntimeError(trim(r)) detach = _add_docstr(_C._TensorBase.detach, r) detach_ = _add_docstr(_C._TensorBase.detach_, r) def retain_grad(self): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.retain_grad, relevant_args, self) + if not self.requires_grad: + raise RuntimeError("MSG") + if self.is_leaf: + return + if hasattr(self, 'MSG'): + return + weak_self = weakref.ref(self) def retain_grad_hook(grad): + var = weak_self() + if var is None: + return + if var._grad is None: + if grad.is_sparse: + var._grad = grad.clone() + else: + var._grad = grad.clone(memory_format=torch.contiguous_format) + else: + var._grad = var._grad + grad self.register_hook(retain_grad_hook) + self.retains_grad = True def is_shared(self): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.is_shared, relevant_args, self) + return self.storage().is_shared() def share_memory_(self): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.share_memory_, relevant_args, self) + self.storage().share_memory_() + return self def __reversed__(self): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__reversed__, relevant_args, self) + if self.dim() == 0: + return self + else: + return self.flip(0) def norm(self, p="MSG", dim=None, keepdim=False, dtype=None): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.norm, relevant_args, self, p=p, dim=dim, keepdim=keepdim, dtype=dtype) + return torch.norm(self, p, dim, keepdim, dtype=dtype) def lu(self, pivot=True, get_infos=False): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.lu, relevant_args, self, pivot=pivot, get_infos=get_infos) + LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos)) + if get_infos: + return LU, pivots, infos + else: + return LU, pivots def stft(self, n_fft: int, hop_length: Optional[int] = None, + win_length: Optional[int] = None, window: 'MSG' = None, + center: bool = True, pad_mode: str = 'MSG', normalized: bool = False, + onesided: Optional[bool] = None, return_complex: Optional[bool] = None): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function( + Tensor.stft, relevant_args, self, n_fft, hop_length=hop_length, + win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized, + onesided=onesided, return_complex=return_complex + ) + return torch.stft(self, n_fft, hop_length, win_length, window, center, + pad_mode, normalized, onesided, return_complex=return_complex) def istft(self, n_fft: int, hop_length: Optional[int] = None, + win_length: Optional[int] = None, window: 'MSG' = None, + center: bool = True, normalized: bool = False, + onesided: Optional[bool] = None, length: Optional[int] = None, + return_complex: bool = False): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function( + Tensor.istft, relevant_args, self, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, normalized=normalized, onesided=onesided, length=length, + return_complex=return_complex + ) + return torch.istft(self, n_fft, hop_length, win_length, window, center, + normalized, onesided, length, return_complex=return_complex) def resize(self, *sizes): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.resize, relevant_args, self, *sizes) + warnings.warn("MSG") + from torch.autograd._functions import Resize + return Resize.apply(self, sizes) def resize_as(self, tensor): + relevant_args = (self, tensor) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and type(tensor) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.resize_as, relevant_args, self, tensor) + warnings.warn("MSG") + from torch.autograd._functions import Resize + return Resize.apply(self, tensor.size()) def split(self, split_size, dim=0): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.split, relevant_args, self, split_size, dim=dim) + if isinstance(split_size, int): + return super(Tensor, self).split(split_size, dim) + elif isinstance(split_size, Tensor): + try: + split_size = int(split_size) + return super(Tensor, self).split(split_size, dim) + except ValueError: + return super(Tensor, self).split_with_sizes(split_size, dim) + else: + return super(Tensor, self).split_with_sizes(split_size, dim) def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function( + Tensor.unique, relevant_args, self, sorted=sorted, return_inverse=return_inverse, + return_counts=return_counts, dim=dim + ) + return torch.unique(self, sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim) def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None): + r + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function( + Tensor.unique_consecutive, relevant_args, self, return_inverse=return_inverse, + return_counts=return_counts, dim=dim + ) + return torch.unique_consecutive(self, return_inverse=return_inverse, return_counts=return_counts, dim=dim) def __rsub__(self, other): + relevant_args = (self, other) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__rsub__, relevant_args, self, other) + return _C._VariableFunctions.rsub(self, other) def __rdiv__(self, other): + relevant_args = (self, other) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__rdiv__, relevant_args, self, other) + if self.dtype.is_floating_point or self.dtype.is_complex: + return self.reciprocal() * other + else: + return self.to(torch.get_default_dtype()).reciprocal() * other __rtruediv__ = __rdiv__ + __itruediv__ = _C._TensorBase.__idiv__ __pow__ = _C._TensorBase.pow def __format__(self, format_spec): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__format__, relevant_args, self, format_spec) + if self.dim() == 0: + return self.item().__format__(format_spec) + return object.__format__(self, format_spec) def __ipow__(self, other): + relevant_args = (self, other) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__ipow__, relevant_args, self, other) + return NotImplemented @_wrap_type_error_to_not_implemented + def __rpow__(self, other): + dtype = torch.result_type(other, self) + return torch.tensor(other, dtype=dtype, device=self.device) ** self @_wrap_type_error_to_not_implemented + def __floordiv__(self, other): + return torch.floor_divide(self, other) @_wrap_type_error_to_not_implemented + def __rfloordiv__(self, other): + return torch.floor_divide(other, self) __neg__ = _C._TensorBase.neg __eq__ = _wrap_type_error_to_not_implemented(_C._TensorBase.eq) + __ne__ = _wrap_type_error_to_not_implemented(_C._TensorBase.ne) + __lt__ = _wrap_type_error_to_not_implemented(_C._TensorBase.lt) + __le__ = _wrap_type_error_to_not_implemented(_C._TensorBase.le) + __gt__ = _wrap_type_error_to_not_implemented(_C._TensorBase.gt) + __ge__ = _wrap_type_error_to_not_implemented(_C._TensorBase.ge) + __abs__ = _C._TensorBase.abs def __len__(self): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__len__, relevant_args, self) + if self.dim() == 0: + raise TypeError("MSG") + return self.shape[0] def __iter__(self): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__iter__, relevant_args, self) + if self.dim() == 0: + raise TypeError('MSG') + if torch._C._get_tracing_state(): + warnings.warn('MSG' + 'MSG't change the number of 'MSG'iterations executed (and might lead to errors or silently give 'MSG'incorrect results).'MSG'Only a small subset of methods are supported for quantized tensors.'MSG'volatile'MSG'uint8'MSG'refine_names'MSG'align_to'MSG'torch'MSG'Storage'MSG'Storage'MSG'Storage'MSG'1.7.1+cu101'MSG'10.1'MSG'57bffc3a8e4fee0cce31e1ff1f662ccf7b16db57'MSG'java'MSG'Windows'MSG'win32'MSG'Mac'MSG'darwin'MSG'linux2'MSG'darwin'MSG'~/Library/Application Support/'MSG'XDG_DATA_HOME'MSG'darwin'MSG'/Library/Application Support'MSG'XDG_DATA_DIRS'MSG'/usr/local/share'MSG'/usr/share'MSG'darwin'MSG'~/Library/Preferences/'MSG'XDG_CONFIG_HOME'MSG'win32'MSG'darwin'MSG'/Library/Preferences'MSG'XDG_CONFIG_DIRS'MSG'/etc/xdg'MSG'darwin'MSG'~/Library/Caches'MSG'XDG_CACHE_HOME'MSG'~/.cache'MSG'XDG_STATE_HOME'MSG'~/Library/Logs'MSG'c'MSG'c'MSG'torch.classes'MSG'Class {self.name}.{attr} not registered!'MSG'torch.classes'MSG'weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]'MSG'.'MSG'.'MSG'.'MSG','MSG'['MSG']'MSG'['MSG'['MSG']'MSG'__code__'MSG'__func__'MSG'_torchscript_modifier'MSG'__module__'MSG'typing'MSG'__origin__'MSG'__origin__'MSG'__module__'MSG'typing'MSG'__origin__'MSG'__origin__'MSG'__module__'MSG'typing'MSG'__origin__'MSG'__origin__'MSG'__module__'MSG'typing'MSG'__origin__'MSG'__args__'MSG'typing'MSG'__origin__'MSG'typing'MSG'typing_extensions'MSG'__origin__'MSG'_jit_override_qualname'MSG''MSG'_lambda'MSG''MSG'_fields'MSG'__annotations__'MSG'lobpcg'MSG'inf'MSG'lobpcg.backward does not support sparse input yet.'MSG'Note that lobpcg.forward does though.'MSG'lobpcg.backward does not support complex input yet.'MSG'Note that lobpcg.forward does though.'MSG'lobpcg.backward does not support backward with B != I yet.'MSG'Script and require grads is not supported atm.'MSG'If you just want to do the forward, use .detach()'MSG'on A and B before calling into lobpcg'MSG'LPBPCG algorithm is not applicable when the number of A rows (={})'MSG' is smaller than 3 x the number of requested eigenpairs (={})'MSG'ortho'MSG'm'MSG'n'MSG'k'MSG'niter'MSG'tol'MSG'largest'MSG'ortho'MSG'ortho_i_max'MSG'ortho_i_max'MSG'ortho_j_max'MSG'ortho_j_max'MSG'ortho_tol'MSG'ortho_tol'MSG'ortho_tol_drop'MSG'ortho_tol_drop'MSG'ortho_tol_replace'MSG'ortho_tol_replace'MSG'ortho_use_drop'MSG'ortho_use_drop'MSG'batch_index'MSG'm'MSG'n'MSG'istep'MSG'_'MSG'_'MSG'LOPBCG:'MSG' iparams={}'MSG' fparams={}'MSG' bparams={}'MSG' ivars={}'MSG' fvars={}'MSG' bvars={}'MSG' tvars={}'MSG' A={}'MSG' B={}'MSG' iK={}'MSG' X={}'MSG' E={}'MSG''MSG'\n'MSG'istep'MSG'X_norm'MSG'A_norm'MSG'B_norm'MSG'iterations_left'MSG'niter'MSG'converged_count'MSG'converged_end'MSG'ortho'MSG'iterations_left'MSG'iterations_left'MSG'istep'MSG'istep'MSG'converged_count'MSG'tol'MSG'A_norm'MSG'B_norm'MSG'the number of converged eigenpairs 'MSG'(was {}, got {}) cannot decrease'MSG'converged_count'MSG'rerr'MSG'force_stop'MSG'iterations_left'MSG'converged_count'MSG'k'MSG'converged_end'MSG'converged_count'MSG'n'MSG'largest'MSG'istep'MSG'converged_end'MSG'converged_end'MSG'converged_end'MSG'converged_count'MSG'n'MSG'largest'MSG'istep'MSG'converged_end'MSG'converged_end'MSG'm'MSG'ortho_tol'MSG'ortho_tol_drop'MSG'ortho_tol_replace'MSG'ortho_i_max'MSG'ortho_j_max'MSG'ortho_use_drop'MSG'ortho_'MSG'_rerr'MSG'ortho_i'MSG'ortho_j'MSG''MSG'ortho_i'MSG'ortho_j'MSG'ortho_UBUmI_rerr[{}, {}]'MSG'ortho_VBU_rerr[{}]'MSG'Overdetermined shape of U:'MSG' + .format(B.shape[-1], U.shape[-1], V.shape[-1])) + self.ivars['MSG'] = i + self.ivars['MSG'] = j + return U +LOBPCG_call_tracker_orig = LOBPCG.call_tracker +def LOBPCG_call_tracker(self): + self.tracker(self) +__all__ = ['MSG', 'MSG'] from typing import Tuple, Optional import torch +from torch import Tensor +from . import _linalg_utils as _utils +from .overrides import has_torch_function, handle_torch_function +def get_approximate_basis(A, q, niter=2, + M=None + ): + + niter = 2 if niter is None else niter + m, n = A.shape[-2:] + dtype = _utils.get_floating_dtype(A) + matmul = _utils.matmul R = torch.randn(n, q, dtype=dtype, device=A.device) A_H = _utils.transjugate(A) + if M is None: + (Q, _) = matmul(A, R).qr() + for i in range(niter): + (Q, _) = matmul(A_H, Q).qr() + (Q, _) = matmul(A, Q).qr() + else: + M_H = _utils.transjugate(M) + (Q, _) = (matmul(A, R) - matmul(M, R)).qr() + for i in range(niter): + (Q, _) = (matmul(A_H, Q) - matmul(M_H, Q)).qr() + (Q, _) = (matmul(A, Q) - matmul(M, Q)).qr() return Q +def svd_lowrank(A, q=6, niter=2, M=None): + + r + if not torch.jit.is_scripting(): + tensor_ops = (A, M) + if (not set(map(type, tensor_ops)).issubset((torch.Tensor, type(None))) and has_torch_function(tensor_ops)): + return handle_torch_function(svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M) + return _svd_lowrank(A, q=q, niter=niter, M=M) +def _svd_lowrank(A, q=6, niter=2, M=None): + + q = 6 if q is None else q + m, n = A.shape[-2:] + matmul = _utils.matmul + if M is None: + M_t = None + else: + M_t = _utils.transpose(M) + A_t = _utils.transpose(A) + + if m < n: + Q = get_approximate_basis(A_t, q, niter=niter, M=M_t) + Q_c = _utils.conjugate(Q) + if M is None: + B_t = matmul(A, Q_c) + else: + B_t = matmul(A, Q_c) - matmul(M, Q_c) + U, S, V = torch.svd(B_t) + V = Q.matmul(V) + else: + Q = get_approximate_basis(A, q, niter=niter, M=M) + Q_c = _utils.conjugate(Q) + if M is None: + B = matmul(A_t, Q_c) + else: + B = matmul(A_t, Q_c) - matmul(M_t, Q_c) + U, S, V = torch.svd(_utils.transpose(B)) + U = Q.matmul(U) return U, S, V +def pca_lowrank(A, q=None, center=True, niter=2): + + r if not torch.jit.is_scripting(): + if type(A) is not torch.Tensor and has_torch_function((A,)): + return handle_torch_function(pca_lowrank, (A,), A, q=q, center=center, niter=niter) (m, n) = A.shape[-2:] if q is None: + q = min(6, m, n) + elif not (q >= 0 and q <= min(m, n)): + raise ValueError('MSG' + 'MSG' + .format(q, min(m, n))) + if not (niter >= 0): + raise ValueError('MSG' + .format(niter)) dtype = _utils.get_floating_dtype(A) if not center: + return _svd_lowrank(A, q, niter=niter, M=None) if _utils.is_sparse(A): + if len(A.shape) != 2: + raise ValueError('MSG') + c = torch.sparse.sum(A, dim=(-2,)) / m + column_indices = c.indices()[0] + indices = torch.zeros(2, len(column_indices), + dtype=column_indices.dtype, + device=column_indices.device) + indices[0] = column_indices + C_t = torch.sparse_coo_tensor( + indices, c.values(), (n, 1), dtype=dtype, device=A.device) ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device) + M = _utils.transpose(torch.sparse.mm(C_t, ones_m1_t)) + return _svd_lowrank(A, q, niter=niter, M=M) + else: + C = A.mean(dim=(-2,), keepdim=True) + return _svd_lowrank(A - C, q, niter=niter, M=None) +from collections import OrderedDict +def check_serializing_named_tensor(tensor): + if tensor.has_names(): + raise RuntimeError( + "MSG" + "MSG") +def build_dim_map(tensor): + + return OrderedDict([(idx if name is None else name, name) + for idx, name in enumerate(tensor.names)]) +def unzip_namedshape(namedshape): + if isinstance(namedshape, OrderedDict): + namedshape = namedshape.items() + if not hasattr(namedshape, 'MSG') and not isinstance(namedshape, tuple): + raise RuntimeError( + 'MSG' + .format(type(namedshape))) + if len(namedshape) == 0: + raise RuntimeError('MSG') + return zip(*namedshape) +def namer_api_name(inplace): + if inplace: + return 'MSG' + else: + return 'MSG' +def is_ellipsis(item): + return item == Ellipsis or item == 'MSG' def single_ellipsis_index(names, fn_name): + ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)] + if len(ellipsis_indices) >= 2: + raise RuntimeError('MSG'...\'MSG' + 'MSG' + .format(fn_name, names)) + if len(ellipsis_indices) == 1: + return ellipsis_indices[0] + return None def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names): + return names[numel_pre_glob:len(names) - numel_post_glob] +def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names): + globbed_names = expand_single_ellipsis(ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names) + return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1:] +def resolve_ellipsis(names, tensor_names, fn_name): + + ellipsis_idx = single_ellipsis_index(names, fn_name) + if ellipsis_idx is None: + return names + return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names) +def update_names_with_list(tensor, names, inplace): + + if len(names) == 1 and names[0] is None: + return tensor._update_names(None, inplace) return tensor._update_names( + resolve_ellipsis(names, tensor.names, namer_api_name(inplace)), inplace) +def update_names_with_mapping(tensor, rename_map, inplace): + dim_map = build_dim_map(tensor) + for old_dim in rename_map.keys(): + new_dim = rename_map[old_dim] + if old_dim in dim_map.keys(): + dim_map[old_dim] = new_dim + else: + raise RuntimeError(('MSG'{old_dim}\'MSG' + 'MSG'{old_dim}\'MSG') + .format(old_dim=old_dim, new_dim=new_dim, dims=tensor.names, + api_name=namer_api_name(inplace))) + return tensor._update_names(tuple(dim_map.values()), inplace) +def update_names(tensor, names, rename_map, inplace): + + has_names = len(names) > 0 + has_rename_pairs = bool(rename_map) + if has_names and has_rename_pairs: + raise RuntimeError('MSG' + 'MSG' + 'MSG' + 'MSG'.format(api_name=namer_api_name(inplace))) + if not has_names and not has_rename_pairs: + return update_names_with_list(tensor, names, inplace) if has_names: + return update_names_with_list(tensor, names, inplace) + return update_names_with_mapping(tensor, rename_map, inplace) +import torch._C import contextlib +import ctypes +import os +import sys +import types import torch.jit +import torch._utils_internal +_SET_GLOBAL_FLAGS = hasattr(sys, 'MSG') and hasattr(sys, 'MSG') +@contextlib.contextmanager +def dl_open_guard(): + + if _SET_GLOBAL_FLAGS: + old_flags = sys.getdlopenflags() + sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL) + yield + if _SET_GLOBAL_FLAGS: + sys.setdlopenflags(old_flags) class _OpNamespace(types.ModuleType): + + def __init__(self, name): + super(_OpNamespace, self).__init__('MSG' + name) + self.name = name def __getattr__(self, op_name): + qualified_op_name = 'MSG'.format(self.name, op_name) + op = torch._C._jit_get_operation(qualified_op_name) + torch.jit._builtins._register_builtin(op, qualified_op_name) + setattr(self, op_name, op) + op.__module__ = self.__module__ + "MSG" + self.name + return op class _Ops(types.ModuleType): + __file__ = os.path.join(os.path.dirname(__file__), 'MSG') def __init__(self): + super(_Ops, self).__init__('MSG') + self.loaded_libraries = set() def __getattr__(self, name): + namespace = _OpNamespace(name) + setattr(self, name, namespace) + return namespace def load_library(self, path): + path = torch._utils_internal.resolve_library_path(path) + with dl_open_guard(): + ctypes.CDLL(path) + self.loaded_libraries.add(path) +ops = _Ops() +import builtins +import collections.abc +import io +import math +import sys +import types +import queue inf = math.inf +nan = math.nan +string_classes = (str, bytes) +int_classes = int +FileNotFoundError = builtins.FileNotFoundError +StringIO = io.StringIO +container_abcs = collections.abc +PY3 = sys.version_info[0] == 3 +PY37 = sys.version_info[0] == 3 and sys.version_info[1] >= 7 def with_metaclass(meta: type, *bases) -> type: + + + + + class metaclass(meta): def __new__(cls, name, this_bases, d): + return meta(name, bases, d) @classmethod + def __prepare__(cls, name, this_bases): + return meta.__prepare__(name, bases) return type.__new__(metaclass, 'MSG', (), {}) def get_function_from_type(cls, name): + return getattr(cls, name, None) +def istuple(obj) -> bool: + + + + + t = type(obj) + return isinstance(obj, tuple) or t.__module__ == 'MSG' def bind_method(fn, obj, obj_type): + return types.MethodType(fn, obj) +import torch._C +from torch._C import _add_docstr as add_docstr +storage_classes = [ + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', +] +def add_docstr_all(method, docstr): + for cls_name in storage_classes: + cls = getattr(torch._C, cls_name) + try: + add_docstr(getattr(cls, method), docstr) + except AttributeError: + pass +add_docstr_all('MSG', + ) +import torch._C +from torch._C import _add_docstr as add_docstr +from ._torch_docs import parse_kwargs +def add_docstr_all(method, docstr): + add_docstr(getattr(torch._C._TensorBase, method), docstr) common_args = parse_kwargs() new_common_args = parse_kwargs() add_docstr_all('MSG', + r.format(**new_common_args)) add_docstr_all('MSG', + r.format(**new_common_args)) add_docstr_all('MSG', + r.format(**new_common_args)) add_docstr_all('MSG', + r.format(**new_common_args)) add_docstr_all('MSG', + r.format(**new_common_args)) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r.format(**common_args)) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r) +add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r.format(**common_args)) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) +add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) add_docstr_all('MSG', + r) +import math +import torch +from torch._six import inf +class __PrinterOptions(object): + precision = 4 + threshold = 1000 + edgeitems = 3 + linewidth = 80 + sci_mode = None +PRINT_OPTS = __PrinterOptions() def set_printoptions( + precision=None, + threshold=None, + edgeitems=None, + linewidth=None, + profile=None, + sci_mode=None +): + r + if profile is not None: + if profile == "MSG": + PRINT_OPTS.precision = 4 + PRINT_OPTS.threshold = 1000 + PRINT_OPTS.edgeitems = 3 + PRINT_OPTS.linewidth = 80 + elif profile == "MSG": + PRINT_OPTS.precision = 2 + PRINT_OPTS.threshold = 1000 + PRINT_OPTS.edgeitems = 2 + PRINT_OPTS.linewidth = 80 + elif profile == "MSG": + PRINT_OPTS.precision = 4 + PRINT_OPTS.threshold = inf + PRINT_OPTS.edgeitems = 3 + PRINT_OPTS.linewidth = 80 if precision is not None: + PRINT_OPTS.precision = precision + if threshold is not None: + PRINT_OPTS.threshold = threshold + if edgeitems is not None: + PRINT_OPTS.edgeitems = edgeitems + if linewidth is not None: + PRINT_OPTS.linewidth = linewidth + PRINT_OPTS.sci_mode = sci_mode +class _Formatter(object): + def __init__(self, tensor): + self.floating_dtype = tensor.dtype.is_floating_point + self.int_mode = True + self.sci_mode = False + self.max_width = 1 with torch.no_grad(): + tensor_view = tensor.reshape(-1) if not self.floating_dtype: + for value in tensor_view: + value_str = 'MSG'.format(value) + self.max_width = max(self.max_width, len(value_str)) else: + nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)) if nonzero_finite_vals.numel() == 0: + return nonzero_finite_abs = nonzero_finite_vals.abs().double() + nonzero_finite_min = nonzero_finite_abs.min().double() + nonzero_finite_max = nonzero_finite_abs.max().double() for value in nonzero_finite_vals: + if value != torch.ceil(value): + self.int_mode = False + break if self.int_mode: + if nonzero_finite_max / nonzero_finite_min > 1000. or nonzero_finite_max > 1.e8: + self.sci_mode = True + for value in nonzero_finite_vals: + value_str = ('MSG').format(PRINT_OPTS.precision).format(value) + self.max_width = max(self.max_width, len(value_str)) + else: + for value in nonzero_finite_vals: + value_str = ('MSG').format(value) + self.max_width = max(self.max_width, len(value_str) + 1) + else: + if nonzero_finite_max / nonzero_finite_min > 1000.\ + or nonzero_finite_max > 1.e8\ + or nonzero_finite_min < 1.e-4: + self.sci_mode = True + for value in nonzero_finite_vals: + value_str = ('MSG').format(PRINT_OPTS.precision).format(value) + self.max_width = max(self.max_width, len(value_str)) + else: + for value in nonzero_finite_vals: + value_str = ('MSG').format(PRINT_OPTS.precision).format(value) + self.max_width = max(self.max_width, len(value_str)) if PRINT_OPTS.sci_mode is not None: + self.sci_mode = PRINT_OPTS.sci_mode def width(self): + return self.max_width def format(self, value): + if self.floating_dtype: + if self.sci_mode: + ret = ('MSG').format(self.max_width, PRINT_OPTS.precision).format(value) + elif self.int_mode: + ret = 'MSG'.format(value) + if not (math.isinf(value) or math.isnan(value)): + ret += 'MSG' + else: + ret = ('MSG').format(PRINT_OPTS.precision).format(value) + else: + ret = 'MSG'.format(value) + return (self.max_width - len(ret)) * 'MSG' + ret +def _scalar_str(self, formatter1, formatter2=None): + if formatter2 is not None: + real_str = _scalar_str(self.real, formatter1) + imag_str = _scalar_str(self.imag, formatter2) + "MSG" + if self.imag < 0: + return real_str + imag_str.lstrip() + else: + return real_str + "MSG" + imag_str.lstrip() + else: + return formatter1.format(self.item()) def _vector_str(self, indent, summarize, formatter1, formatter2=None): + + element_length = formatter1.width() + 2 + if formatter2 is not None: + element_length += formatter2.width() + 1 elements_per_line = max(1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))) + char_per_line = element_length * elements_per_line def _val_formatter(val, formatter1=formatter1, formatter2=formatter2): + if formatter2 is not None: + real_str = formatter1.format(val.real) + imag_str = formatter2.format(val.imag) + "MSG" + if val.imag < 0: + return real_str + imag_str.lstrip() + else: + return real_str + "MSG" + imag_str.lstrip() + else: + return formatter1.format(val) if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: + data = ([_val_formatter(val) for val in self[:PRINT_OPTS.edgeitems].tolist()] + + ['MSG'] + + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems:].tolist()]) + else: + data = [_val_formatter(val) for val in self.tolist()] data_lines = [data[i:i + elements_per_line] for i in range(0, len(data), elements_per_line)] + lines = ['MSG'.join(line) for line in data_lines] + return 'MSG' + ('MSG' + 'MSG' + 'MSG' * (indent + 1)).join(lines) + 'MSG' +def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None): + dim = self.dim() if dim == 0: + return _scalar_str(self, formatter1, formatter2) if dim == 1: + return _vector_str(self, indent, summarize, formatter1, formatter2) if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: + slices = ([_tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1, formatter2) + for i in range(0, PRINT_OPTS.edgeitems)] + + ['MSG'] + + [_tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1, formatter2) + for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]) + else: + slices = [_tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1, formatter2) + for i in range(0, self.size(0))] tensor_str = ('MSG' + 'MSG' * (dim - 1) + 'MSG' * (indent + 1)).join(slices) + return 'MSG' + tensor_str + 'MSG' def _tensor_str(self, indent): + if self.numel() == 0: + return 'MSG' if self.has_names(): + self = self.rename(None) summarize = self.numel() > PRINT_OPTS.threshold + if self.dtype is torch.float16 or self.dtype is torch.bfloat16: + self = self.float() if self.dtype.is_complex: + real_formatter = _Formatter(get_summarized_data(self.real) if summarize else self.real) + imag_formatter = _Formatter(get_summarized_data(self.imag) if summarize else self.imag) + return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter) + else: + formatter = _Formatter(get_summarized_data(self) if summarize else self) + return _tensor_str_with_formatter(self, indent, summarize, formatter) def _add_suffixes(tensor_str, suffixes, indent, force_newline): + tensor_strs = [tensor_str] + last_line_len = len(tensor_str) - tensor_str.rfind('MSG') + 1 + for suffix in suffixes: + suffix_len = len(suffix) + if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth: + tensor_strs.append('MSG' + 'MSG' * indent + suffix) + last_line_len = indent + suffix_len + force_newline = False + else: + tensor_strs.append('MSG' + suffix) + last_line_len += suffix_len + 2 + tensor_strs.append('MSG') + return 'MSG'.join(tensor_strs) +def get_summarized_data(self): + dim = self.dim() + if dim == 0: + return self + if dim == 1: + if self.size(0) > 2 * PRINT_OPTS.edgeitems: + return torch.cat((self[:PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems:])) + else: + return self + if self.size(0) > 2 * PRINT_OPTS.edgeitems: + start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)] + end = ([self[i] + for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]) + return torch.stack([get_summarized_data(x) for x in (start + end)]) + else: + return torch.stack([get_summarized_data(x) for x in self]) def _str_intern(self): + prefix = 'MSG' + indent = len(prefix) + suffixes = [] + + + + + + + if self.device.type != torch._C._get_default_device()\ + or (self.device.type == 'MSG' and torch.cuda.current_device() != self.device.index): + suffixes.append('MSG''MSG'\'MSG') + _default_complex_dtype = torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat + has_default_dtype = self.dtype in (torch.get_default_dtype(), _default_complex_dtype, torch.int64, torch.bool) + if self.is_sparse: + suffixes.append('MSG' + str(tuple(self.shape))) + suffixes.append('MSG' + str(self._nnz())) + if not has_default_dtype: + suffixes.append('MSG' + str(self.dtype)) + indices_prefix = 'MSG' + indices = self._indices().detach() + indices_str = _tensor_str(indices, indent + len(indices_prefix)) + if indices.numel() == 0: + indices_str += 'MSG' + str(tuple(indices.shape)) + values_prefix = 'MSG' + values = self._values().detach() + values_str = _tensor_str(values, indent + len(values_prefix)) + if values.numel() == 0: + values_str += 'MSG' + str(tuple(values.shape)) + tensor_str = indices_prefix + indices_str + 'MSG' + 'MSG' * indent + values_prefix + values_str + 'MSG' + elif self.is_quantized: + suffixes.append('MSG' + str(tuple(self.shape))) + if not has_default_dtype: + suffixes.append('MSG' + str(self.dtype)) + suffixes.append('MSG' + str(self.qscheme())) + if self.qscheme() == torch.per_tensor_affine or self.qscheme() == torch.per_tensor_symmetric: + suffixes.append('MSG' + str(self.q_scale())) + suffixes.append('MSG' + str(self.q_zero_point())) + elif self.qscheme() == torch.per_channel_affine or self.qscheme() == torch.per_channel_symmetric \ + or self.qscheme() == torch.per_channel_affine_float_qparams: + suffixes.append('MSG' + str(self.q_per_channel_scales())) + suffixes.append('MSG' + str(self.q_per_channel_zero_points())) + suffixes.append('MSG' + str(self.q_per_channel_axis())) + tensor_str = _tensor_str(self.dequantize(), indent) + else: + if self.is_meta: + suffixes.append('MSG' + str(tuple(self.shape))) + if self.dtype != torch.get_default_dtype(): + suffixes.append('MSG' + str(self.dtype)) + tensor_str = 'MSG' + else: + if self.numel() == 0 and not self.is_sparse: + if self.dim() != 1: + suffixes.append('MSG' + str(tuple(self.shape))) if self.dtype != torch.get_default_dtype(): + suffixes.append('MSG' + str(self.dtype)) + tensor_str = 'MSG' + else: + if not has_default_dtype: + suffixes.append('MSG' + str(self.dtype)) if self.layout != torch.strided: + tensor_str = _tensor_str(self.to_dense(), indent) + else: + tensor_str = _tensor_str(self, indent) if self.layout != torch.strided: + suffixes.append('MSG' + str(self.layout)) if self.grad_fn is not None: + name = type(self.grad_fn).__name__ + if name == 'MSG': + name = self.grad_fn.name().rsplit('MSG', 1)[-1] + suffixes.append('MSG'.format(name)) + elif self.requires_grad: + suffixes.append('MSG') if self.has_names(): + suffixes.append('MSG'.format(self.names)) return _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse) def _str(self): + with torch.no_grad(): + return _str_intern(self) +import re import torch._C +from torch._C import _add_docstr as add_docstr +def parse_kwargs(desc): + + + regx = re.compile(r"MSG") + kwargs = [section.strip() for section in regx.split(desc)] + kwargs = [section for section in kwargs if len(section) > 0] + return {desc.split('MSG')[0]: desc for desc in kwargs} +def merge_dicts(*dicts): + return {x: d[x] for d in dicts for x in d} +common_args = parse_kwargs() reduceops_common_args = merge_dicts(common_args, parse_kwargs()) multi_dim_common = merge_dicts(reduceops_common_args, parse_kwargs(), {'MSG': }) single_dim_common = merge_dicts(reduceops_common_args, parse_kwargs(), {'MSG': }) factory_common_args = merge_dicts(common_args, parse_kwargs()) factory_like_common_args = parse_kwargs() factory_data_common_args = parse_kwargs() tf32_notes = { + "MSG": +} add_docstr(torch.abs, r + r.format(**common_args)) add_docstr(torch.absolute, + r.format(**common_args)) add_docstr(torch.acos, r + r.format(**common_args)) add_docstr(torch.arccos, r) add_docstr(torch.acosh, r + r.format(**common_args)) add_docstr(torch.arccosh, r.format(**common_args)) add_docstr(torch.add, r.format(**common_args)) add_docstr(torch.addbmm, + r + r.format(**common_args, **tf32_notes)) add_docstr(torch.addcdiv, r + r.format(**common_args)) add_docstr(torch.addcmul, + r + r.format(**common_args)) add_docstr(torch.addmm, + r + r.format(**common_args, **tf32_notes)) add_docstr(torch.addmv, + r + r.format(**common_args)) add_docstr(torch.addr, + r + r.format(**common_args)) add_docstr(torch.allclose, + r + r) add_docstr(torch.angle, + r + r.format(**common_args)) add_docstr(torch.as_strided, + r.format(**common_args)) add_docstr(torch.as_tensor, + r.format(**factory_data_common_args)) add_docstr(torch.asin, r + r.format(**common_args)) add_docstr(torch.arcsin, r) add_docstr(torch.asinh, + r + r.format(**common_args)) add_docstr(torch.arcsinh, r) add_docstr(torch.atan, r + r.format(**common_args)) add_docstr(torch.arctan, r) add_docstr(torch.atan2, + r.format(**common_args)) add_docstr(torch.atanh, r + r.format(**common_args)) add_docstr(torch.arctanh, r) add_docstr(torch.baddbmm, + r + r.format(**common_args, **tf32_notes)) add_docstr(torch.bernoulli, + r + r.format(**common_args)) add_docstr(torch.bincount, + r) add_docstr(torch.bitwise_not, + r.format(**common_args)) +add_docstr(torch.bmm, + r + r.format(**common_args, **tf32_notes)) add_docstr(torch.bitwise_and, + r.format(**common_args)) add_docstr(torch.bitwise_or, + r.format(**common_args)) add_docstr(torch.bitwise_xor, + r.format(**common_args)) add_docstr(torch.stack, + r.format(**common_args)) add_docstr(torch.hstack, + r.format(**common_args)) add_docstr(torch.vstack, + r.format(**common_args)) add_docstr(torch.dstack, + r.format(**common_args)) add_docstr(torch.chunk, + r) add_docstr(torch.unsafe_chunk, + r) add_docstr(torch.unsafe_split, + r) add_docstr(torch.can_cast, + r) add_docstr(torch.cat, + r.format(**common_args)) add_docstr(torch.ceil, + r + r.format(**common_args)) add_docstr(torch.real, + r.format(**common_args)) add_docstr(torch.imag, + r.format(**common_args)) add_docstr(torch.view_as_real, + r.format(**common_args)) add_docstr(torch.view_as_complex, + r.format(**common_args)) add_docstr(torch.reciprocal, + r + r.format(**common_args)) add_docstr(torch.cholesky, r) add_docstr(torch.cholesky_solve, r) add_docstr(torch.cholesky_inverse, r) add_docstr(torch.clone, r.format(**common_args)) add_docstr(torch.clamp, r + r.format(**common_args)) add_docstr(torch.clip, r.format(**common_args)) add_docstr(torch.complex, + r) add_docstr(torch.polar, + r + r) add_docstr(torch.conj, + r + r.format(**common_args)) add_docstr(torch.cos, + r + r.format(**common_args)) add_docstr(torch.cosh, + r + r.format(**common_args)) add_docstr(torch.cross, + r.format(**common_args)) add_docstr(torch.logcumsumexp, + r.format(**reduceops_common_args)) add_docstr(torch.cummax, + r.format(**reduceops_common_args)) add_docstr(torch.cummin, + r.format(**reduceops_common_args)) add_docstr(torch.cumprod, + r.format(**reduceops_common_args)) add_docstr(torch.cumsum, + r.format(**reduceops_common_args)) add_docstr(torch.count_nonzero, + r.format(**reduceops_common_args)) add_docstr(torch.dequantize, + r) add_docstr(torch.diag, + r.format(**common_args)) add_docstr(torch.diag_embed, + r.format(**common_args)) +add_docstr(torch.diagflat, + r.format(**common_args)) add_docstr(torch.diagonal, + r.format(**common_args)) add_docstr(torch.digamma, + r + r.format(**common_args)) +add_docstr(torch.dist, + r.format(**common_args)) add_docstr(torch.div, r.format(**common_args)) add_docstr(torch.divide, r) add_docstr(torch.dot, + r) add_docstr(torch.vdot, + r.format(**common_args)) add_docstr(torch.eig, + r) add_docstr(torch.eq, r.format(**common_args)) add_docstr(torch.equal, + r) add_docstr(torch.erf, + r + r.format(**common_args)) add_docstr(torch.erfc, + r + r.format(**common_args)) add_docstr(torch.erfinv, + r + r.format(**common_args)) add_docstr(torch.exp, + r + r.format(**common_args)) add_docstr(torch.exp2, + r + r.format(**common_args)) add_docstr(torch.expm1, + r + r.format(**common_args)) +add_docstr(torch.eye, + r.format(**factory_common_args)) add_docstr(torch.floor, + r + r.format(**common_args)) add_docstr(torch.floor_divide, r + r.format(**common_args)) add_docstr(torch.fmod, + r.format(**common_args)) add_docstr(torch.frac, + r) add_docstr(torch.from_numpy, + r) add_docstr(torch.flatten, + r.format(**common_args)) +add_docstr(torch.gather, + r + r) +add_docstr(torch.gcd, + r.format(**common_args)) add_docstr(torch.ge, r + r.format(**common_args)) add_docstr(torch.greater_equal, r) add_docstr(torch.geqrf, + r) add_docstr(torch.outer, r) add_docstr(torch.ger, + r) add_docstr(torch.solve, + r) add_docstr(torch.get_default_dtype, + r) add_docstr(torch.get_num_threads, + r) add_docstr(torch.get_num_interop_threads, + r) add_docstr(torch.gt, r + r.format(**common_args)) add_docstr(torch.greater, r) add_docstr(torch.histc, + r.format(**common_args)) add_docstr(torch.hypot, + r + r.format(**common_args)) add_docstr(torch.i0, + r + r.format(**common_args)) add_docstr(torch.index_select, + r.format(**common_args)) add_docstr(torch.inverse, + r.format(**common_args)) add_docstr(torch.isinf, r) add_docstr(torch.isposinf, + r.format(**common_args)) add_docstr(torch.isneginf, + r.format(**common_args)) add_docstr(torch.isclose, r + r) add_docstr(torch.isfinite, r.format(**common_args)) add_docstr(torch.isnan, r.format(**common_args)) add_docstr(torch.isreal, r.format(**common_args)) add_docstr(torch.is_floating_point, r.format(**common_args)) add_docstr(torch.is_complex, r.format(**common_args)) add_docstr(torch.is_nonzero, r.format(**common_args)) add_docstr(torch.kthvalue, + r.format(**single_dim_common)) add_docstr(torch.lcm, + r.format(**common_args)) add_docstr(torch.le, r + r.format(**common_args)) add_docstr(torch.less_equal, r) add_docstr(torch.lerp, + r + r.format(**common_args)) add_docstr(torch.lgamma, + r + .format(**common_args)) +add_docstr(torch.linspace, r + .format(**factory_common_args)) add_docstr(torch.log, + r + r.format(**common_args)) add_docstr(torch.log10, + r + r.format(**common_args)) add_docstr(torch.log1p, + r + r.format(**common_args)) add_docstr(torch.log2, + r + r.format(**common_args)) add_docstr(torch.logaddexp, + r.format(**common_args)) add_docstr(torch.logaddexp2, + r.format(**common_args)) add_docstr(torch.logical_and, + r.format(**common_args)) add_docstr(torch.logical_not, + r.format(**common_args)) add_docstr(torch.logical_or, + r.format(**common_args)) add_docstr(torch.logical_xor, + r.format(**common_args)) +add_docstr(torch.logspace, + r + .format(**factory_common_args)) add_docstr(torch.logsumexp, + r.format(**multi_dim_common)) add_docstr(torch.lstsq, + r) add_docstr(torch.lt, r + r.format(**common_args)) add_docstr(torch.less, r) add_docstr(torch.lu_solve, + r.format(**common_args)) add_docstr(torch.masked_select, + r.format(**common_args)) add_docstr(torch.matrix_rank, + r) add_docstr(torch.matrix_power, + r.format(**common_args)) add_docstr(torch.matrix_exp, + r + r.format(**common_args)) add_docstr(torch.max, + r.format(**single_dim_common)) add_docstr(torch.maximum, r.format(**common_args)) add_docstr(torch.amax, + r.format(**multi_dim_common)) add_docstr(torch.argmax, + r.format(**single_dim_common)) add_docstr(torch.mean, + r.format(**multi_dim_common)) add_docstr(torch.median, + r.format(**single_dim_common)) add_docstr(torch.quantile, + r.format(**single_dim_common)) add_docstr(torch.nanquantile, + r.format(**single_dim_common)) add_docstr(torch.min, + r.format(**single_dim_common)) add_docstr(torch.minimum, r.format(**common_args)) add_docstr(torch.amin, + r.format(**multi_dim_common)) add_docstr(torch.argmin, + r.format(**single_dim_common)) add_docstr(torch.mm, + r.format(**common_args, **tf32_notes)) add_docstr(torch.matmul, + r.format(**common_args, **tf32_notes)) add_docstr(torch.mode, + r.format(**single_dim_common)) add_docstr(torch.mul, r + r + r.format(**common_args)) add_docstr(torch.multiply, r.format(**common_args)) add_docstr(torch.multinomial, + r.format(**common_args)) add_docstr(torch.mv, + r.format(**common_args)) add_docstr(torch.mvlgamma, + r) add_docstr(torch.movedim, r.format(**common_args)) add_docstr(torch.narrow, + r) add_docstr(torch.ne, r + r.format(**common_args)) add_docstr(torch.not_equal, r) add_docstr(torch.neg, + r + r.format(**common_args)) add_docstr(torch.negative, + r.format(**common_args)) add_docstr(torch.nextafter, + r.format(**common_args)) add_docstr(torch.nonzero, + r.format(**common_args)) add_docstr(torch.normal, + r.format(**common_args)) add_docstr(torch.numel, + r.format(**common_args)) +add_docstr(torch.ones, + r.format(**factory_common_args)) +add_docstr(torch.ones_like, + r.format(**factory_like_common_args)) add_docstr(torch.orgqr, + r) add_docstr(torch.ormqr, + r) add_docstr(torch.poisson, + r.format(**common_args)) add_docstr(torch.polygamma, + r + .format(**common_args)) add_docstr(torch.pow, + r + r.format(**common_args)) add_docstr(torch.prod, + r.format(**single_dim_common)) add_docstr(torch.promote_types, + r) add_docstr(torch.qr, + r) add_docstr(torch.rad2deg, + r.format(**common_args)) add_docstr(torch.deg2rad, + r.format(**common_args)) add_docstr(torch.heaviside, + r + r.format(**common_args)) add_docstr(torch.rand, + r.format(**factory_common_args)) add_docstr(torch.rand_like, + r.format(**factory_like_common_args)) add_docstr(torch.randint, + .format(**factory_common_args)) add_docstr(torch.randint_like, + .format(**factory_like_common_args)) add_docstr(torch.randn, + r.format(**factory_common_args)) add_docstr(torch.randn_like, + r.format(**factory_like_common_args)) add_docstr(torch.randperm, + r.format(**factory_common_args)) add_docstr(torch.tensor, + r.format(**factory_data_common_args)) add_docstr(torch.range, + r + r.format(**factory_common_args)) add_docstr(torch.arange, + r + r.format(**factory_common_args)) add_docstr(torch.remainder, + r.format(**common_args)) add_docstr(torch.renorm, + r.format(**common_args)) add_docstr(torch.reshape, + r) +add_docstr(torch.result_type, + r) +add_docstr(torch.round, + r.format(**common_args)) add_docstr(torch.rsqrt, + r + r.format(**common_args)) add_docstr(torch.set_flush_denormal, + r) add_docstr(torch.set_num_threads, r) add_docstr(torch.set_num_interop_threads, r) add_docstr(torch.sigmoid, r + r.format(**common_args)) add_docstr(torch.logit, + r + r.format(**common_args)) add_docstr(torch.sign, + r + r.format(**common_args)) add_docstr(torch.signbit, + r.format(**common_args)) add_docstr(torch.sgn, + r + r.format(**common_args)) add_docstr(torch.sin, + r + r.format(**common_args)) add_docstr(torch.sinh, + r + r.format(**common_args)) add_docstr(torch.sort, + r.format(**common_args)) add_docstr(torch.argsort, + r.format(**common_args)) add_docstr(torch.sparse_coo_tensor, + r.format(**factory_common_args)) add_docstr(torch.sqrt, + r + r.format(**common_args)) add_docstr(torch.square, + r.format(**common_args)) add_docstr(torch.squeeze, + r.format(**common_args)) add_docstr(torch.std, r.format(**multi_dim_common)) add_docstr(torch.std_mean, + r.format(**multi_dim_common)) add_docstr(torch.sub, r + r.format(**common_args)) add_docstr(torch.subtract, r) add_docstr(torch.sum, + r.format(**multi_dim_common)) add_docstr(torch.nansum, + r.format(**multi_dim_common)) add_docstr(torch.svd, + r) add_docstr(torch.symeig, + r) add_docstr(torch.t, + r.format(**common_args)) add_docstr(torch.flip, + r.format(**common_args)) add_docstr(torch.fliplr, + r.format(**common_args)) add_docstr(torch.flipud, + r.format(**common_args)) add_docstr(torch.roll, + r.format(**common_args)) add_docstr(torch.rot90, + r.format(**common_args)) add_docstr(torch.take, + r.format(**common_args)) add_docstr(torch.tan, + r + r.format(**common_args)) add_docstr(torch.tanh, + r + r.format(**common_args)) add_docstr(torch.topk, + r.format(**common_args)) add_docstr(torch.trace, + r) add_docstr(torch.transpose, + r.format(**common_args)) add_docstr(torch.triangular_solve, + r) add_docstr(torch.tril, + r + r.format(**common_args)) add_docstr(torch.tril_indices, + r + r.format(**factory_common_args)) add_docstr(torch.triu, + r + r.format(**common_args)) add_docstr(torch.triu_indices, + r + r.format(**factory_common_args)) add_docstr(torch.true_divide, r.format(**common_args)) add_docstr(torch.trunc, + r.format(**common_args)) add_docstr(torch.fix, + r.format(**common_args)) add_docstr(torch.unsqueeze, + r.format(**common_args)) add_docstr(torch.var, r.format(**multi_dim_common)) add_docstr(torch.var_mean, + r.format(**multi_dim_common)) add_docstr(torch.zeros, + r.format(**factory_common_args)) add_docstr(torch.zeros_like, + r.format(**factory_like_common_args)) add_docstr(torch.empty, + r.format(**factory_common_args)) add_docstr(torch.empty_like, + r.format(**factory_like_common_args)) add_docstr(torch.empty_strided, + r.format(**factory_common_args)) add_docstr(torch.full, r.format(**factory_common_args)) add_docstr(torch.full_like, + .format(**factory_like_common_args)) add_docstr(torch.det, + r) add_docstr(torch.where, + r) add_docstr(torch.logdet, + r) add_docstr(torch.slogdet, + r) add_docstr(torch.pinverse, + r) add_docstr(torch.fft, r) add_docstr(torch.ifft, r) add_docstr(torch.rfft, r) +add_docstr(torch.irfft, r) +add_docstr(torch.hann_window, + + r + r.format(**factory_common_args)) +add_docstr(torch.hamming_window, + + r + r.format(**factory_common_args)) +add_docstr(torch.bartlett_window, + + r + r.format(**factory_common_args)) +add_docstr(torch.blackman_window, + + r + r.format(**factory_common_args)) +add_docstr(torch.kaiser_window, + r + r.format(**factory_common_args)) +add_docstr(torch.vander, + + r.format(**factory_common_args)) +add_docstr(torch.unbind, + r) +add_docstr(torch.combinations, + r) add_docstr(torch.trapz, + r) add_docstr(torch.repeat_interleave, + r.format(**common_args)) +add_docstr(torch.quantize_per_tensor, + r) add_docstr(torch.quantize_per_channel, + r) add_docstr(torch.Generator, + r) +add_docstr(torch.Generator.set_state, + r) +add_docstr(torch.Generator.get_state, + r) +add_docstr(torch.Generator.manual_seed, + r) +add_docstr(torch.Generator.initial_seed, + r) +add_docstr(torch.Generator.seed, + r) +add_docstr(torch.Generator.device, + r) add_docstr(torch.searchsorted, + r) add_docstr(torch.bucketize, + r) +import torch +import torch._six +from typing import Optional +import warnings +from collections import defaultdict +import sys +import traceback +def _type(self, dtype=None, non_blocking=False, **kwargs): + + non_blocking = _get_async_or_non_blocking('MSG', non_blocking, kwargs) + if dtype is None: + return self.__module__ + 'MSG' + self.__class__.__name__ if isinstance(dtype, str): + dtype = _import_dotted_name(dtype) + if dtype == type(self): + return self + if self.is_sparse: + if not dtype.is_sparse: + raise RuntimeError("MSG") + new_module_name = dtype.__module__.replace('MSG', 'MSG') + new_values_type_name = new_module_name + 'MSG' + dtype.__name__ + new_values = torch._values(self).type(new_values_type_name, non_blocking) + new_indices_type_name = new_module_name + 'MSG' + new_indices = torch._indices(self).type(new_indices_type_name, non_blocking) + return dtype(new_indices, new_values, self.size()) + if dtype.is_sparse: + raise RuntimeError("MSG") + return dtype(self.size()).copy_(self, non_blocking) +def _cuda(self, device=None, non_blocking=False, **kwargs): + + non_blocking = _get_async_or_non_blocking('MSG', non_blocking, kwargs) + if self.is_cuda: + if device is None: + device = torch.cuda.current_device() + if self.get_device() == device: + return self + else: + if device is None: + device = -1 + with torch.cuda.device(device): + if self.is_sparse: + new_type = getattr(torch.cuda.sparse, self.__class__.__name__) + indices = torch._indices(self).cuda(device, non_blocking) + values = torch._values(self).cuda(device, non_blocking) + return new_type(indices, values, self.size()) + else: + new_type = getattr(torch.cuda, self.__class__.__name__) + return new_type(self.size()).copy_(self, non_blocking) +def _get_async_or_non_blocking(function_name, non_blocking, kwargs): + if not kwargs: + return non_blocking + if len(kwargs) != 1 or 'MSG' not in kwargs: + message = "MSG" + argument = list(kwargs.keys()).pop() + raise TypeError(message.format(function_name, argument)) + warnings.warn("MSG") + return kwargs['MSG'] def _rebuild_tensor(storage, storage_offset, size, stride): + + t = torch.tensor([], dtype=storage.dtype, device=storage.device) + return t.set_(storage, storage_offset, size, stride) +def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): + tensor = _rebuild_tensor(storage, storage_offset, size, stride) + tensor.requires_grad = requires_grad + + + + tensor._backward_hooks = backward_hooks + return tensor +_sparse_tensors_to_validate = [] def _validate_loaded_sparse_tensors(): + try: + for t in _sparse_tensors_to_validate: + torch._validate_sparse_coo_tensor_args(t._indices(), t._values(), + t.size()) + finally: + _sparse_tensors_to_validate.clear() def _rebuild_sparse_tensor(layout, data): + if layout == torch.sparse_coo: + indices, values, size = data + result = torch._sparse_coo_tensor_unsafe(indices, values, size) + _sparse_tensors_to_validate.append(result) + return result raise NotImplementedError("MSG" % (layout)) +def _rebuild_xla_tensor(data, dtype, device, requires_grad): + tensor = torch.from_numpy(data).to(dtype=dtype, device=device) + tensor.requires_grad = requires_grad + return tensor +def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks): + qscheme = quantizer_params[0] + if qscheme == torch.per_tensor_affine: + _, scale, zero_point = quantizer_params + tensor = torch._empty_affine_quantized(size, scale=scale, zero_point=zero_point, dtype=storage.dtype) + elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams): + _, scales, zero_points, axis = quantizer_params + if type(scales) is list and type(zero_points) is list: + if qscheme == torch.per_channel_affine: + scales = torch.tensor(scales, dtype=torch.double) + zero_points = torch.tensor(zero_points, dtype=torch.long) + else: + scales = torch.tensor(scales, dtype=torch.float) + zero_points = torch.tensor(zero_points, dtype=torch.float) + tensor = torch._empty_per_channel_affine_quantized( + size, scales=scales, zero_points=zero_points, axis=axis, dtype=storage.dtype) + else: + raise RuntimeError("MSG".format(qscheme)) + tensor.set_(storage, storage_offset, size, stride) + tensor.requires_grad = requires_grad + + + + tensor._backward_hooks = backward_hooks + return tensor def _rebuild_parameter(data, requires_grad, backward_hooks): + param = torch.nn.Parameter(data, requires_grad) + + + + param._backward_hooks = backward_hooks return param +def _import_dotted_name(name): + components = name.split('MSG') + obj = __import__(components[0]) + for component in components[1:]: + obj = getattr(obj, component) + return obj def _accumulate(iterable, fn=lambda x, y: x + y): + 'MSG' + + + it = iter(iterable) + try: + total = next(it) + except StopIteration: + return + yield total + for element in it: + total = fn(total, element) + yield total +def _flatten_dense_tensors(tensors): + + if len(tensors) == 1: + return tensors[0].contiguous().view(-1) + flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) + return flat +def _flatten_sparse_tensors(tensors): + + flat_indices = _flatten_dense_tensors([torch._indices(t) for t in tensors]) + flat_values = _flatten_dense_tensors([torch._values(t) for t in tensors]) + return flat_indices, flat_values +def _unflatten_dense_tensors(flat, tensors): + + outputs = [] + offset = 0 + for tensor in tensors: + numel = tensor.numel() + outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) + offset += numel + return tuple(outputs) +def _unflatten_sparse_tensors(flat, tensors): + + flat_indices, flat_values = flat + indices = _unflatten_dense_tensors(flat_indices, [torch._indices(t) for t in tensors]) + values = _unflatten_dense_tensors(flat_values, [torch._values(t) for t in tensors]) + outputs = [] + for t, i, v in zip(tensors, indices, values): + outputs.append(t.new(i, v, t.size())) + return tuple(outputs) +def _reorder_tensors_as(tensors, ordered_tensors): + + type_dict = defaultdict(list) + for tensor in tensors: + type_dict[tensor.type()].append(tensor) + type_dict = {t: iter(coll) for t, coll in type_dict.items()} + return tuple(next(type_dict[tensor.type()]) for tensor in ordered_tensors) +def _take_tensors(tensors, size_limit): + + buf_dict = defaultdict(lambda: [[], 0]) + for tensor in tensors: + t = tensor.type() + if tensor.is_sparse: + indices = torch._indices(tensor) + values = torch._values(tensor) + size = indices.numel() * indices.element_size() + values.numel() * values.element_size() + else: + size = tensor.numel() * tensor.element_size() + buf_and_size = buf_dict[t] + if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: + yield buf_and_size[0] + buf_and_size = buf_dict[t] = [[], 0] + buf_and_size[0].append(tensor) + buf_and_size[1] += size + for buf, _ in buf_dict.values(): + if len(buf) > 0: + yield buf +def annotate(ret, **kwargs): + def dec(fun): + fun.__annotations__ = dict(kwargs) + fun.__annotations__['MSG'] = ret + return fun + return dec +class KeyErrorMessage(str): + r + def __repr__(self): + return self +class ExceptionWrapper(object): + r + def __init__(self, exc_info=None, where="MSG"): + if exc_info is None: + exc_info = sys.exc_info() + self.exc_type = exc_info[0] + self.exc_msg = "MSG".join(traceback.format_exception(*exc_info)) + self.where = where def reraise(self): + r + msg = "MSG".format( + self.exc_type.__name__, self.where, self.exc_msg) + if self.exc_type == KeyError: + msg = KeyErrorMessage(msg) + elif getattr(self.exc_type, "MSG", None): + raise self.exc_type(message=msg) + raise self.exc_type(msg) +def _get_available_device_type(): + if torch.cuda.is_available(): + return "MSG" + + return None +def _get_device_attr(get_member): + device_type = _get_available_device_type() + if device_type.lower() == "MSG": + return get_member(torch.cuda) + + return None +def _get_current_device_index(): + + return _get_device_attr(lambda m: m.current_device()) +def _get_all_device_indices(): + + return _get_device_attr(lambda m: list(range(m.device_count()))) +def _get_devices_properties(device_ids): + + return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids] +def _get_device_index(device, optional=False, allow_cpu=False) -> int: + r + if isinstance(device, str): + device = torch.device(device) + device_idx: Optional[int] + device_idx = None + if isinstance(device, torch.device): + if not allow_cpu and device.type == 'MSG': + raise ValueError('MSG'.format(device)) + device_idx = -1 if device.type == 'MSG' else device.index + if isinstance(device, int): + device_idx = device + if device_idx is None: + if optional: + device_idx = _get_current_device_index() + else: + raise ValueError('MSG' + 'MSG'.format(device)) + return device_idx import os +import inspect +import tempfile +if os.path.basename(os.path.dirname(__file__)) == 'MSG': + torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) +else: + torch_parent = os.path.dirname(os.path.dirname(__file__)) +def get_file_path(*path_components): + return os.path.join(torch_parent, *path_components) +def get_file_path_2(*path_components): + return os.path.join(*path_components) +def get_writable_path(path): + if os.access(path, os.W_OK): + return path + return tempfile.mkdtemp(suffix=os.path.basename(path)) def prepare_multiprocessing_environment(path): + pass +def resolve_library_path(path): + return os.path.realpath(path) +def get_source_lines_and_file(obj, error_msg=None): + + filename = None + try: + filename = inspect.getsourcefile(obj) + sourcelines, file_lineno = inspect.getsourcelines(obj) + except OSError as e: + msg = (f"MSG" + "MSG" + "MSG") + if error_msg: + msg += 'MSG' + error_msg + raise OSError(msg) from e return sourcelines, file_lineno, filename +TEST_MASTER_ADDR = 'MSG' +TEST_MASTER_PORT = 29500 +USE_GLOBAL_DEPS = True +USE_RTLD_GLOBAL_WITH_LIBTORCH = False import torch +import sys +import types +class VFModule(types.ModuleType): + vf: types.ModuleType def __init__(self, name): + super(VFModule, self).__init__(name) + self.vf = torch._C._VariableFunctions def __getattr__(self, attr): + return getattr(self.vf, attr) +sys.modules[__name__] = VFModule(__name__) +import torch +import functools +from torch import Tensor +from typing import Any, Callable, Optional, Tuple, Union +import warnings in_dims_t = Union[int, Tuple[Optional[int], ...]] +out_dims_t = Union[int, Tuple[int, ...]] +def _validate_and_get_batch_size( + in_dims_as_tuple: Tuple[Optional[int], ...], + args: Tuple) -> int: + batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(in_dims_as_tuple, args) + if in_dim is not None] + if batch_sizes and any([size != batch_sizes[0] for size in batch_sizes]): + raise ValueError( + f'MSG' + f'MSG') + return batch_sizes[0] +def _check_args_can_be_mapped_with_in_dims( + in_dims_as_tuple: Tuple[Optional[int], ...], + args: Tuple, + func: Callable, + in_dims: in_dims_t) -> None: + for idx, (in_dim, arg) in enumerate(zip(in_dims_as_tuple, args)): + if in_dim is None: + continue + if not isinstance(in_dim, int): + raise ValueError( + f'MSG' + f'MSG' + f'MSG' + f'MSG') + if not isinstance(arg, Tensor): + raise ValueError( + f'MSG' + f'MSG' + f'MSG'ed over. 'MSG'If you were trying to vmap over a Tensor inside a Python 'MSG'collection in `inputs`, we do not yet support that; otherwise, 'MSG'use None as the respective in_dim for input {idx}.'MSG'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): Got in_dim={in_dim} 'MSG'for input {idx}, but input {idx} is a Tensor of dimensionality 'MSG'{arg.dim()} so expected in_dim to satisfy 0 <= in_dim < {arg.dim()}.'MSG'vmap({_get_name(func)}, in_dims={in_dims}, ...): expected `in_dims` to 'MSG'be int or tuple, got: {type(in_dims)}.'MSG'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): expected 'MSG'one `in_dim` per input (got {len(args)} inputs) of {_get_name(func)}'MSG'vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add 'MSG'inputs, or you are trying to vmap over a function with no inputs. 'MSG'The latter is unsupported.'MSG'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must 'MSG'have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.'MSG'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return 'MSG'Tensors, got type {type(outputs)} as the return.'MSG'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return 'MSG'Tensors, got type {type(output)} for return {idx}.'MSG'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be 'MSG'an int or a tuple of int representing where in the outputs the 'MSG'vmapped dimension should appear.'MSG'__name__'MSG'torch.vmap is an experimental prototype that is subject to 'MSG'change and/or deletion. Please use at your own risk.'MSG'typename'MSG'is_tensor'MSG'is_storage'MSG'set_default_tensor_type'MSG'set_rng_state'MSG'get_rng_state'MSG'manual_seed'MSG'initial_seed'MSG'seed'MSG'save'MSG'load'MSG'set_printoptions'MSG'chunk'MSG'split'MSG'stack'MSG'matmul'MSG'no_grad'MSG'enable_grad'MSG'rand'MSG'randn'MSG'DoubleStorage'MSG'FloatStorage'MSG'LongStorage'MSG'IntStorage'MSG'ShortStorage'MSG'CharStorage'MSG'ByteStorage'MSG'BoolStorage'MSG'DoubleTensor'MSG'FloatTensor'MSG'LongTensor'MSG'IntTensor'MSG'ShortTensor'MSG'CharTensor'MSG'ByteTensor'MSG'BoolTensor'MSG'Tensor'MSG'lobpcg'MSG'set_deterministic'MSG'is_deterministic'MSG'win32'MSG'ProgramFiles'MSG'C:\\Program Files'MSG'Library'MSG'bin'MSG'lib'MSG'Library'MSG'bin'MSG''MSG'nvToolsExt64_1.dll'MSG'NVTOOLSEXT_PATH'MSG'NVIDIA Corporation'MSG'NvToolsExt'MSG'bin'MSG'x64'MSG''MSG'cudart64*.dll'MSG'.'MSG'_'MSG'CUDA_PATH_V'MSG'NVIDIA GPU Computing Toolkit'MSG'CUDA'MSG'v'MSG'bin'MSG''MSG'kernel32.dll'MSG'AddDllDirectory'MSG' Error adding "MSG" to the DLL directories.'MSG'vcruntime140.dll'MSG'msvcp140.dll'MSG'9.2'MSG'10.0'MSG'vcruntime140_1.dll'MSG'*.dll'MSG' Error loading "MSG" or one of its dependencies.'MSG'PATH'MSG';'MSG'PATH'MSG' Error loading "MSG" or one of its dependencies.'MSG'Windows'MSG'libtorch_global_deps'MSG'.dylib'MSG'Darwin'MSG'.so'MSG'lib'MSG'TORCH_USE_RTLD_GLOBAL'MSG'Windows'MSG'RTLD_GLOBAL'MSG'RTLD_LAZY'MSG'_'MSG'Base'MSG''MSG''MSG'__module__'MSG'builtins'MSG'__builtin__'MSG'.'MSG'__qualname__'MSG'__name__'MSG'Windows'MSG'torch'MSG'bin'MSG'torch_shm_manager'MSG'torch'MSG'utf-8'MSG'__'MSG'Anomaly Detection has been enabled. 'MSG'This mode will increase the runtime 'MSG'and should only be enabled for debugging.'MSG'mark_shared_storage is deprecated. 'MSG'Tensors with shared storages are automatically tracked. Note 'MSG'that calls to `set_()` are not tracked'MSG'forward'MSG'Backward'MSG'_forward_cls'MSG't know how to process "MSG"an input object of type "MSG". Accepted types: "MSG", or lists/tuples of them"MSG""MSG"_jit_unwrap"MSG"Auto nesting doesn'MSG's Values or None"MSG"Tensors"MSG"Tensors (permissive)"MSG"Tensors or None"MSG"Tensors"MSG"The {} given to {} must be either a Tensor or a tuple of Tensors but the"MSG" value at index {} has type {}."MSG"The {} given to {} must be either a Tensor or a tuple of Tensors but the"MSG" given {} has type {}."MSG"v is a tuple of invalid length: should be {} but got {}."MSG"The given v should contain a single Tensor."MSG""MSG"Entry {} in "MSG"{}v has invalid size: should be {} but got {}."MSG"outputs"MSG"grad_inputs"MSG"jacobian"MSG"hessian"MSG"Invalid input_type to _check_requires_grad"MSG"The output of the user-provided function is independent of input {}."MSG" This is not allowed in strict mode."MSG"hessian"MSG"The hessian of the user-provided function with respect to input {}"MSG" is independent of the input. This is not allowed in strict mode."MSG" You should ensure that your function is thrice differentiable and that"MSG" the hessian depends on the inputs."MSG"jacobian"MSG"While computing the hessian, found that the jacobian of the user-provided"MSG" function with respect to input {} is independent of the input. This is not"MSG" allowed in strict mode. You should ensure that your function is twice"MSG" differentiable and that the jacobian depends on the inputs (this would be"MSG" violated by a linear function for example)."MSG"grad_inputs"MSG"The gradient with respect to input {} is independent of the inputs of the"MSG" user-provided function. This is not allowed in strict mode."MSG"Output {} of the user-provided function does not require gradients."MSG" The outputs must be computed in a differentiable manner from the input"MSG" when running in strict mode."MSG"back"MSG"back_trick"MSG"double_back"MSG"double_back_trick"MSG"Invalid stage argument 'MSG' to _fill_in_zeros"MSG"back"MSG"The output of the user-provided function is independent of "MSG"input {}. This is not allowed in strict mode."MSG"back_trick"MSG"The gradient with respect to the input is independent of entry {}"MSG" in the grad_outputs when using the double backward trick to compute"MSG" forward mode gradients. This is not allowed in strict mode."MSG"double_back"MSG"The jacobian of the user-provided function is independent of "MSG"input {}. This is not allowed in strict mode."MSG"The hessian of the user-provided function is independent of "MSG"entry {} in the grad_jacobian. This is not allowed in strict "MSG"mode as it prevents from using the double backward trick to "MSG"replace forward mode AD."MSG"double"MSG"The jacobian of the user-provided function is independent of "MSG"input {}. This is not allowed in strict mode when create_graph=True."MSG"The hessian of the user-provided function is independent of "MSG"input {}. This is not allowed in strict mode when create_graph=True."MSG"inputs"MSG"vjp"MSG"outputs of the user-provided function"MSG"vjp"MSG"outputs"MSG"v"MSG"vjp"MSG"The vector v can only be None if the "MSG"user-provided function returns "MSG"a single Tensor with a single element."MSG"back"MSG"inputs"MSG"jvp"MSG"v"MSG"jvp"MSG"The vector v can only be None if the input to "MSG"the user-provided function is a single Tensor "MSG"with a single element."MSG"outputs of the user-provided function"MSG"jvp"MSG"outputs"MSG"grad_inputs"MSG"back_trick"MSG"inputs"MSG"jacobian"MSG"outputs of the user-provided function"MSG"jacobian"MSG"outputs"MSG"The jacobian of the user-provided function is "MSG"independent of input {}. This is not allowed in "MSG"strict mode when create_graph=True."MSG"Output {} of the user-provided function is "MSG"independent of input {}. This is not allowed in "MSG"strict mode."MSG"inputs"MSG"hessian"MSG"outputs of the user-provided function"MSG"hessian"MSG"outputs"MSG"The function given to hessian should return a single Tensor"MSG"The Tensor returned by the function given to hessian should contain a single element"MSG"jacobian"MSG"inputs"MSG"vhp"MSG"v"MSG"vhp"MSG"The vector v can only be None if the input to the user-provided function "MSG"is a single Tensor with a single element."MSG"outputs of the user-provided function"MSG"vhp"MSG"outputs"MSG"The function given to vhp should return a single Tensor"MSG"The Tensor returned by the function given to vhp should contain a single element"MSG"jacobian"MSG"double_back"MSG"inputs"MSG"hvp"MSG"v"MSG"hvp"MSG"The vector v can only be None if the input to the user-provided function "MSG"is a single Tensor with a single element."MSG"outputs of the user-provided function"MSG"hvp"MSG"outputs"MSG"The function given to hvp should return a single Tensor"MSG"The Tensor returned by the function given to hvp should contain a single element"MSG"jacobian"MSG"hessian"MSG"double_back_trick"MSG"Gradients failed to compare equal for grad output = 1j. "MSG"Gradients failed to compare equal for grad output = 1. "MSG"Backward"MSG" is not reentrant, i.e., running backward with same \ + input and grad_output multiple times gives different values, \ + although analytical gradient matches numerical gradient. \ + The tolerance for nondeterminism was {}."MSG"no Tensors requiring grad found in input"MSG"grad is incorrect type"MSG"check_undefined_grad=False"MSG"Notes about undefined output gradients"MSG"tools/autograd/derivatives.yaml"MSG"Notes about undefined output gradients"MSG"tools/autograd/derivatives.yaml"MSG"thread"MSG"There is already a CPU parent event for {}"MSG"["MSG"name"MSG"%s"MSG"ph"MSG"X"MSG"ts"MSG"dur"MSG"tid"MSG"pid"MSG"CPU functions"MSG"args"MSG" node_id:{evt.node_id}, thread_id:{evt.thread} "MSG"name"MSG"%s"MSG"ph"MSG"s"MSG"ts"MSG"tid"MSG"pid"MSG"CPU functions"MSG"id"MSG"cat"MSG"cpu_to_cuda"MSG"args"MSG"name"MSG"%s"MSG"ph"MSG"f"MSG"ts"MSG"tid"MSG"pid"MSG"CUDA functions"MSG"id"MSG"cat"MSG"cpu_to_cuda"MSG"args"MSG"name"MSG"%s"MSG"ph"MSG"X"MSG"ts"MSG"dur"MSG"tid"MSG"pid"MSG"CUDA functions"MSG"args"MSG"]"MSG""MSG"autograd profiler traces are not reentrant"MSG"can'MSG't finish running"MSG"_call_end_callbacks_on_future can only be called once."MSG"NVTX annotation context manager is not reentrant"MSG"Expected time_us == 0 but got {}"MSG"NaN"MSG"profiler::_record_function_enter"MSG"profiler::_record_function_exit"MSG"aten::is_leaf"MSG"aten::output_nr"MSG"aten::_version"MSG"autograd/__init__"MSG"_make_grads"MSG"autograd/__init__"MSG"backward"MSG"torch/tensor"MSG"backward"MSG"_internal/common_utils"MSG"prof_callable"MSG"_internal/common_utils"MSG"prof_func_call"MSG"_internal/common_utils"MSG"prof_meth_call"MSG"SELECT _id_ as id, value FROM StringTable"MSG"id"MSG"value"MSG""MSG""MSG""MSG""MSG""MSG""MSG"Self CPU time total: {}"MSG"CUDA time total: {}"MSG"Mismatch in shape: grad_output["MSG"] has a shape of "MSG" and output["MSG"] has a shape of "MSG"."MSG"For complex Tensors, both grad_output and output"MSG" are required to have the same dtype."MSG" Mismatch in dtype: grad_output["MSG"] has a dtype of "MSG" and output["MSG"] has a dtype of "MSG"."MSG"grad can be implicitly created only for scalar outputs"MSG"gradients can be either Tensors or None, but got "MSG"'MSG' is deprecated. Use 'MSG' instead."MSG"'MSG' and 'MSG' (deprecated) "MSG"arguments both passed to backward(). Please only "MSG"use 'MSG'."MSG"only_inputs argument is deprecated and is ignored now "MSG"(defaults to True). To accumulate gradient for other "MSG"parts of the graph, please use torch.autograd.backward."MSG"torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead"MSG"autograd initialization failed"MSG"requested resize to {} ({} elements in total), "MSG"but the given tensor has a size of {} ({} elements). "MSG"autograd'MSG't be loaded, is cuda enabled?"MSG"GPU:{device}"MSG"no processes are running"MSG"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory"MSG"\n"MSG"Inputs should be a collection of tensors"MSG"'MSG' and 'MSG' can not be both specified. 'MSG' is deprecated in "MSG"favor of 'MSG', taking in a single output tensor. The signature of reduce is: "MSG"reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."MSG"nccl.reduce with an output tensor list is deprecated. "MSG"Please specify a single output tensor with argument 'MSG' instead instead."MSG"nccl.reduce with an output tensor list is deprecated. "MSG"Please specify a single output tensor."MSG"NVTX functions not installed. Are you sure you have a CUDA build?"MSG"gpustarttimestamp"MSG"gpuendtimestamp"MSG"gridsize3d"MSG"threadblocksize"MSG"streamid"MSG"enableonstart 0"MSG"conckerneltrace"MSG"HIP does not support profiler initialization!"MSG"supported CUDA profiler output modes are: key_value and csv"MSG"Tried to instantiate dummy base class {}"MSG"__init__"MSG"_cuda_isInBadFork"MSG" "MSG"To use CUDA with multiprocessing, you must use Python "MSG"3.4+ and the 'MSG' start method"MSG"To use CUDA with multiprocessing, you must use the "MSG"'MSG' start method"MSG"Cannot re-initialize CUDA in forked subprocess. "MSG"Torch not compiled with CUDA enabled"MSG"libcudart functions unavailable. It looks like you have a broken build?"MSG"CUDA call failed lazily at initialization with error: {str(e)}\n\n"MSG"CUDA call was originally invoked at:\n\n{orig_traceback}"MSG"Invalid device id"MSG""MSG"_"MSG" "MSG"-gencode compute=compute_{arch},code={kind}_{arch}"MSG"torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling."MSG"cast_inputs"MSG"cast_inputs"MSG"stage"MSG"found_inf_per_device"MSG"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling."MSG"The growth factor must be > 1.0."MSG"The backoff factor must be < 1.0."MSG"This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."MSG"Attempted {} but _scale is None. "MSG"Attempted {} but _growth_tracker is None. "MSG"_growth_tracker initialized before _scale"MSG"outputs must be a Tensor or an iterable of Tensors"MSG"params"MSG"Attempting to unscale FP16 gradients."MSG"unscale_"MSG"stage"MSG"unscale_() has already been called on this optimizer since the last update()."MSG"stage"MSG"unscale_() is being called after step()."MSG"found_inf_per_device"MSG"stage"MSG"closure"MSG"Closure use is not currently supported if GradScaler is enabled."MSG"step"MSG"stage"MSG"step() has already been called since the last update()."MSG"_step_supports_amp_scaling"MSG"stage"MSG"stage"MSG"found_inf_per_device"MSG"No inf checks were recorded for this optimizer."MSG"found_inf_per_device"MSG"stage"MSG"update"MSG"new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."MSG"found_inf_per_device"MSG"No inf checks were recorded prior to update."MSG"scale"MSG"growth_factor"MSG"backoff_factor"MSG"growth_interval"MSG"_growth_tracker"MSG"The source state dict is empty, possibly because it was saved "MSG"from a disabled instance of GradScaler."MSG"scale"MSG"scale"MSG"growth_factor"MSG"backoff_factor"MSG"growth_interval"MSG"_growth_tracker"MSG"_growth_tracker"MSG"A GradScaler instance may only be pickled at the beginning "MSG"of an iteration, or at the end after scaler.update()."MSG"_check_inf_per_device"MSG"found_inf_per_device"MSG"found_inf_per_device"MSG"found_inf_per_device"MSG"undefined"MSG"gloo"MSG"nccl"MSG"mpi"MSG"tcp"MSG"Backend name must be a string, but got: {}"MSG"TCP backend has been deprecated. Please use "MSG"Gloo or MPI backend for collective operations "MSG"on CPU tensors."MSG"Invalid backend: 'MSG'"MSG"torch.distributed.reduce_op is deprecated, please use "MSG"torch.distributed.ReduceOp instead"MSG"group.WORLD does not have local rank to global "MSG"rank mapping"MSG"The given group does not exist"MSG"The global rank {rank} is not part of the group {group}"MSG"group.WORLD does not have local rank to global "MSG"rank mapping"MSG"The group rank is not part of the group"MSG"Default process group is not initialized"MSG"The given group does not exist"MSG"Invalid function argument. Expected parameter `{}` "MSG"to be of type torch.Tensor."MSG"Invalid function argument. Expected parameter `{}` "MSG"to be of type List[torch.Tensor]."MSG"Default process group has not been initialized, "MSG"please make sure to call init_process_group."MSG"Default process group has not been initialized, "MSG"please make sure to call init_process_group."MSG"Invalid process group specified"MSG"Expected timeout argument to be of type"MSG"datetime.timedelta"MSG"trying to initialize the default process group "MSG"twice!"MSG"Cannot specify both init_method and store."MSG"env://"MSG"For MPI backend, world_size ({}) and rank ({}) "MSG"are ignored since they are assigned by the "MSG"MPI runtime."MSG"The specified group name has already been "MSG"created, please use a different group name"MSG"Expected timeout argument to be of type"MSG"datetime.timedelta"MSG"Distributed package doesn'MSG't have NCCL "MSG"built in"MSG"Invalid process group specified"MSG"tensor"MSG"tensor"MSG"tensor"MSG"tensor"MSG"tensor"MSG"tensor"MSG"tensor"MSG"tensor"MSG"cpu"MSG"cpu"MSG"cpu"MSG"cpu"MSG"tensor_list"MSG"tensor"MSG"tensor_list"MSG"Invalid function argument: "MSG"output_tensor_lists should be a list"MSG"output_tensor_lists"MSG"Argument ``gather_list`` must be specified on destination rank."MSG"Argument ``gather_list`` must NOT be specified "MSG"on non-destination ranks."MSG"tensor"MSG"gather_list"MSG"tensor"MSG"scatter_list"MSG"Argument ``scatter_list`` must be specified "MSG"on source rank."MSG"Argument ``scatter_list`` must NOT be specified "MSG"on non-source ranks."MSG"output"MSG"input_list"MSG"output"MSG"input"MSG"output_tensor_list"MSG"input_tensor_list"MSG"the new group'MSG's rank should be within the "MSG"the world_size set by init_process_group"MSG"PyTorch distributed training launch "MSG"helper utility that will spawn up "MSG"multiple distributed processes"MSG"--nnodes"MSG"The number of nodes to use for distributed "MSG"training"MSG"--node_rank"MSG"The rank of the node for multi-node distributed "MSG"training"MSG"--nproc_per_node"MSG"The number of processes to launch on each node, "MSG"for GPU training, this is recommended to be set "MSG"to the number of GPUs in your system so that "MSG"each process can be bound to a single GPU."MSG"--master_addr"MSG"127.0.0.1"MSG"Master node (rank 0)'MSG's free port that needs to "MSG"be used for communication during distributed "MSG"training"MSG"--use_env"MSG"store_true"MSG"Use environment variable to pass "MSG"'MSG'. For legacy reasons, the default value is False. "MSG"If set to True, the script will not pass "MSG"--local_rank as argument, and will instead set LOCAL_RANK."MSG"-m"MSG"--module"MSG"store_true"MSG"Changes each process to interpret the launch script "MSG"as a python module, executing with the same behavior as"MSG"'MSG'."MSG"--no_python"MSG"store_true"MSG"Do not prepend the training script with \"MSG" - just exec "MSG"it directly. Useful when the script is not a Python script."MSG"training_script"MSG"The full path to the single GPU training "MSG"program/script to be launched in parallel, "MSG"followed by all the arguments for the "MSG"training script"MSG"MASTER_ADDR"MSG"MASTER_PORT"MSG"WORLD_SIZE"MSG"OMP_NUM_THREADS"MSG"*****************************************\n"MSG"Setting OMP_NUM_THREADS environment variable for each process "MSG"to be {} in default, to avoid your system being overloaded, "MSG"please further tune the variable for optimal performance in "MSG"your application as needed. \n"MSG"*****************************************"MSG"OMP_NUM_THREADS"MSG"RANK"MSG"LOCAL_RANK"MSG"-u"MSG"-m"MSG"When using the 'MSG' flag, you must also set the 'MSG' flag."MSG"Don'MSG'--no_python'MSG'--module'MSG'Shutdown Proceed'MSG't match pg rank {}"MSG"world_size argument {} doesn'MSG't match pg rank {}"MSG"world_size argument {} doesn'MSG'{py_ast.body[0].name}'MSG'type'MSG' + wrong_type_lines = list(filter(lambda line: type_pattern.search(line[1]), lines)) + if len(wrong_type_lines) > 0: + raise RuntimeError("MSG" + str(wrong_type_lines[0][0]) + + "MSG"\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/ + + "MSG") + return None + elif len(type_lines) == 1: + return type_lines[0][1].strip() + + return_line = None + parameter_type_lines = [] + for line_num, line in type_lines: + if 'MSG' + try: + arrow_pos = type_line.index('MSG') + except ValueError: + raise RuntimeError("MSG") from None + return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip() +def try_real_annotations(fn, loc): + + try: + sig = inspect.signature(fn) + except ValueError: + return None all_annots = [sig.return_annotation] + [p.annotation for p in sig.parameters.values()] + if all(ann is sig.empty for ann in all_annots): + return None def as_ann(ann): + return ann if ann is not sig.empty else None arg_types = [ann_to_type(as_ann(p.annotation), loc) + for p in sig.parameters.values()] + return_type = ann_to_type(as_ann(sig.return_annotation), loc) + return arg_types, return_type +def get_enum_value_type(e: Type[enum.Enum], loc): + enum_values: List[enum.Enum] = list(e) + if not enum_values: + raise ValueError(f"MSG") types = {type(v.value) for v in enum_values} + ir_types = [try_ann_to_type(t, loc) for t in types] + + + + return torch._C.unify_type_list(ir_types) +def try_ann_to_type(ann, loc): + if ann is None: + return TensorType.get() + if inspect.isclass(ann) and issubclass(ann, torch.Tensor): + return TensorType.get() + if is_tuple(ann): + return TupleType([try_ann_to_type(a, loc) for a in ann.__args__]) + if is_list(ann): + elem_type = try_ann_to_type(ann.__args__[0], loc) + if elem_type: + return ListType(elem_type) + if is_dict(ann): + key = try_ann_to_type(ann.__args__[0], loc) + value = try_ann_to_type(ann.__args__[1], loc) + return DictType(key, value) + if is_optional(ann): + if issubclass(ann.__args__[1], type(None)): + contained = ann.__args__[0] + else: + contained = ann.__args__[1] + valid_type = try_ann_to_type(contained, loc) + msg = "MSG" + assert valid_type, msg.format(repr(ann), repr(contained)) + return OptionalType(valid_type) + if torch.distributed.rpc.is_available() and is_rref(ann): + return RRefType(try_ann_to_type(ann.__args__[0], loc)) + if is_future(ann): + return FutureType(try_ann_to_type(ann.__args__[0], loc)) + if ann is float: + return FloatType.get() + if ann is int: + return IntType.get() + if ann is str: + return StringType.get() + if ann is bool: + return BoolType.get() + if ann is Any: + return AnyType.get() + if ann is type(None): + return NoneType.get() + if inspect.isclass(ann) and hasattr(ann, "MSG"): + return InterfaceType(_qualified_name(ann)) + if ann is torch.device: + return DeviceObjType.get() + if ann is torch.dtype: + return IntType.get() + if inspect.isclass(ann) and issubclass(ann, enum.Enum): + qualified_name = _qualified_name(ann) + if _get_script_class(qualified_name) is None: + torch.jit._script._recursive_compile_class(ann, loc) + return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc), list(ann)) + if inspect.isclass(ann): + qualified_name = _qualified_name(ann) + if _get_script_class(qualified_name) is not None: + return ClassType(qualified_name) + ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception) + if torch._jit_internal.can_compile_class(ann) and not issubclass(ann, ignored_builtin_classes): + torch.jit._script._recursive_compile_class(ann, loc) + return ClassType(qualified_name) + def fake_rcb(key): + return None + return torch._C._resolve_type_from_object(ann, loc, fake_rcb) +def ann_to_type(ann, loc): + the_type = try_ann_to_type(ann, loc) + if the_type is not None: + return the_type + raise ValueError(f"MSG") +__all__ = [ + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + + + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', +] +import torch +import sys +import ast +import inspect +import string +from textwrap import dedent +from torch._C._jit_tree_views import ( + ClassDef, Ident, Stmt, Decl, Def, Var, + EmptyTypeAnnotation, Param, ExprStmt, Assign, + Delete, Return, Raise, Assert, AugAssign, While, + For, If, Pass, Break, Continue, Apply, Dots, Select, + TrueLiteral, FalseLiteral, NoneLiteral, Starred, + ListLiteral, TupleLiteral, DictLiteral, Const, + StringLiteral, ListComp, Attribute, BinOp, UnaryOp, + SliceExpr, Subscript, TernaryIf, With, WithItem, Property, +) +from torch._utils_internal import get_source_lines_and_file from torch._jit_internal import SourceContext, should_drop, is_static_fn +import torch.jit.annotations +_reserved_prefix = 'MSG' +_reserved_names = {'MSG'} +_identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits) +def is_reserved_name(name): + return name.startswith(_reserved_prefix) or name in _reserved_names +pretty_node_names = { + ast.FunctionDef: "MSG", + ast.For: "MSG", + ast.Delete: "MSG", + ast.ClassDef: "MSG", + ast.With: "MSG", + ast.Raise: "MSG", + ast.Assert: "MSG", + ast.Import: "MSG", + ast.ImportFrom: "MSG", + ast.Global: "MSG", + ast.Break: "MSG", + ast.Continue: "MSG", +} node_start_tokens = { + ast.FunctionDef: "MSG", + ast.For: "MSG", + ast.Delete: "MSG", + ast.ClassDef: "MSG", + ast.With: "MSG", + ast.Raise: "MSG", + ast.Assert: "MSG", + ast.Import: "MSG", + ast.ImportFrom: "MSG", + ast.Global: "MSG", + ast.Break: "MSG", + ast.Continue: "MSG", +} pretty_node_names.update({ + ast.AsyncFunctionDef: "MSG", + ast.AsyncFor: "MSG", + ast.AsyncWith: "MSG", + ast.Try: "MSG", + ast.Nonlocal: "MSG", +}) node_start_tokens.update({ + ast.AsyncFunctionDef: "MSG", + ast.AsyncFor: "MSG", + ast.AsyncWith: "MSG", + ast.Try: "MSG", + ast.Nonlocal: "MSG", +}) if sys.version_info >= (3, 6): + pretty_node_names.update({ + ast.AnnAssign: "MSG", + }) + +class FrontendError(Exception): + def __init__(self, source_range, msg): + self.source_range = source_range + self.msg = msg self.error_report = torch._C.ErrorReport(self.source_range) def __str__(self): + return self.msg + self.error_report.what().lstrip() +class NotSupportedError(FrontendError): + pass +class UnsupportedNodeError(NotSupportedError): + def __init__(self, ctx, offending_node, reason='MSG'): + node_type = type(offending_node) + range_len = len(node_start_tokens.get(node_type, 'MSG')) + source_range = ctx.make_range(offending_node.lineno, + offending_node.col_offset, + offending_node.col_offset + range_len) + feature_name = pretty_node_names.get(node_type, node_type.__name__) + msg = "MSG".format(feature_name, reason + 'MSG' if reason else 'MSG') + super(UnsupportedNodeError, self).__init__(source_range, msg) +class FrontendTypeError(FrontendError): + pass +def build_withitems(ctx, items): + items = [build_withitem(ctx, i) for i in items] + return list(items) +def build_stmts(ctx, stmts): + stmts = [build_stmt(ctx, s) for s in stmts] + return list(filter(None, stmts)) +def get_class_properties(cls, self_name): + + props = inspect.getmembers( + cls, predicate=lambda m: isinstance(m, property)) + + unused_properties = getattr(cls, "MSG", []) + properties = [] + for prop in props: + if prop[0] not in unused_properties and not should_drop(prop[1].fget): + getter = get_jit_def(prop[1].fget, f"MSG", self_name=self_name) + setter = get_jit_def(prop[1].fset, f"MSG", self_name=self_name) if prop[1].fset else None + properties.append(Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter)) return properties +def get_jit_class_def(cls, self_name): + + + methods = inspect.getmembers( + cls, + predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m)) + and not is_static_fn(cls, m.__name__) + and m.__name__ in cls.__dict__ + ) + methods = [get_jit_def(method[1], + method[0], + self_name=self_name) for method in methods] properties = get_class_properties(cls, self_name) sourcelines, file_lineno, filename = get_source_lines_and_file(cls, torch._C.ErrorReport.call_stack()) + source = 'MSG'.join(sourcelines) + dedent_src = dedent(source) + py_ast = ast.parse(dedent_src) + leading_whitespace_len = len(source.split('MSG', 1)[0]) - len(dedent_src.split('MSG', 1)[0]) + ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, False) + return build_class_def(ctx, py_ast.body[0], methods, properties, self_name) +def get_jit_def(fn, def_name, self_name=None): + + sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack()) + source = 'MSG'.join(sourcelines) + dedent_src = dedent(source) + py_ast = ast.parse(dedent_src) + if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): + raise RuntimeError("MSG") + leading_whitespace_len = len(source.split('MSG', 1)[0]) - len(dedent_src.split('MSG', 1)[0]) + type_line = torch.jit.annotations.get_type_line(source) + ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True) + fn_def = py_ast.body[0] + if should_drop(fn): + unused_fn_def = ast.parse("MSG"Cannot call @unused methods\"MSG") + if len(unused_fn_def.body) != 1 or not isinstance(unused_fn_def.body[0], ast.FunctionDef): + raise RuntimeError("MSG") + unused_def = unused_fn_def.body[0] + fn_def.body = unused_def.body + fn_def.args.kwarg = fn_def.args.vararg = None + for arg in fn_def.args.args + fn_def.args.kwonlyargs: + arg.annotation = unused_def.args.args[0].annotation return build_def(ctx, fn_def, type_line, def_name, self_name=self_name) +class Builder(object): + def __call__(self, ctx, node): + method = getattr(self, 'MSG' + node.__class__.__name__, None) + if method is None: + raise UnsupportedNodeError(ctx, node) + return method(ctx, node) +def build_class_def(ctx, py_def, methods, properties, self_name): + r = ctx.make_range(py_def.lineno, py_def.col_offset, + py_def.col_offset + len("MSG")) + return ClassDef(Ident(r, self_name), [Stmt(method) for method in methods], properties) +def build_def(ctx, py_def, type_line, def_name, self_name=None): + body = py_def.body + r = ctx.make_range(py_def.lineno + len(py_def.decorator_list), + py_def.col_offset, + py_def.col_offset + len("MSG")) + param_list = build_param_list(ctx, py_def.args, self_name) + return_type = None + if getattr(py_def, 'MSG', None) is not None: + return_type = build_expr(ctx, py_def.returns) + decl = Decl(r, param_list, return_type) + is_method = self_name is not None + if type_line is not None: + type_comment_decl = torch._C.parse_type_comment(type_line) + decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method) return Def(Ident(r, def_name), + decl, + build_stmts(ctx, body)) +_vararg_kwarg_err = ("MSG" + "MSG") +def build_param_list(ctx, py_args, self_name): + if py_args.kwarg is not None: + expr = py_args.kwarg + ctx_range = ctx.make_range(expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg)) + raise NotSupportedError(ctx_range, _vararg_kwarg_err) + if py_args.vararg is not None: + expr = py_args.vararg + ctx_range = ctx.make_range(expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg)) + raise NotSupportedError(ctx_range, _vararg_kwarg_err) + if len(py_args.kw_defaults) > 0: + for arg in py_args.kw_defaults: + if arg is not None: + ctx_range = build_expr(ctx, arg).range() + raise NotSupportedError(ctx_range, _vararg_kwarg_err) + result = [build_param(ctx, arg, self_name, False) for arg in py_args.args] + result += [build_param(ctx, arg, self_name, True) for arg in py_args.kwonlyargs] + return result +def build_param(ctx, py_arg, self_name, kwarg_only): + + name = py_arg.arg + r = ctx.make_range(py_arg.lineno, py_arg.col_offset, py_arg.col_offset + len(name)) + if getattr(py_arg, 'MSG', None) is not None: + annotation_expr = build_expr(ctx, py_arg.annotation) + elif self_name is not None and name == 'MSG': + annotation_expr = Var(Ident(r, self_name)) + else: + annotation_expr = EmptyTypeAnnotation(r) + return Param(annotation_expr, Ident(r, name), kwarg_only) +def get_default_args(fn): + if fn is None: + return {} signature = inspect.signature(fn) + return { + k: v.default + for k, v in signature.parameters.items() + if v.default is not inspect.Parameter.empty + } +def get_default_args_for_class(cls): + + + + methods = inspect.getmembers( + cls, + predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m)) + and not is_static_fn(cls, m.__name__) + and m.__name__ in cls.__dict__ + ) + + defaults = {method_name: get_default_args(method_impl) for method_name, method_impl in methods} return defaults +class WithItemBuilder(Builder): + @staticmethod + def build_withitem(ctx, item): + lineno = item.context_expr.lineno + start = item.context_expr.col_offset + end = start + len(pretty_node_names[ast.With]) + op_vars = item.optional_vars + r = ctx.make_range(lineno, start, end) return WithItem(r, build_expr(ctx, item.context_expr), build_expr(ctx, op_vars) if op_vars else None) +class StmtBuilder(Builder): + augassign_map = { + ast.Add: 'MSG', + ast.Sub: 'MSG', + ast.Mult: 'MSG', + ast.Div: 'MSG', + ast.Mod: 'MSG', + } @staticmethod + def build_Expr(ctx, stmt): + value = stmt.value + if value.__class__.__name__ == 'MSG': + return None + else: + return ExprStmt(build_expr(ctx, value)) @staticmethod + def build_Assign(ctx, stmt): + rhs = build_expr(ctx, stmt.value) + lhs = list(map(lambda x: build_expr(ctx, x), stmt.targets)) + return Assign(lhs, rhs) @staticmethod + def build_AnnAssign(ctx, stmt): + if stmt.value is None: + raise UnsupportedNodeError(ctx, stmt, reason='MSG') + rhs = build_expr(ctx, stmt.value) + lhs = build_expr(ctx, stmt.target) + the_type = build_expr(ctx, stmt.annotation) + return Assign([lhs], rhs, the_type) @staticmethod + def build_Delete(ctx, stmt): + if len(stmt.targets) > 1: + source_range = ctx.make_range(stmt.lineno, stmt.col_offset, + stmt.col_offset + len("MSG")) + raise NotSupportedError( + source_range, 'MSG') + return Delete(build_expr(ctx, stmt.targets[0])) @staticmethod + def build_Return(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value)) @staticmethod + def build_Raise(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + expr = build_expr(ctx, stmt.exc) + return Raise(r, expr) @staticmethod + def build_Assert(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + test = build_expr(ctx, stmt.test) + msg = build_expr(ctx, stmt.msg) if stmt.msg is not None else None + return Assert(r, test, msg) @staticmethod + def build_AugAssign(ctx, stmt): + lhs = build_expr(ctx, stmt.target) + rhs = build_expr(ctx, stmt.value) + op = type(stmt.op) + if op in StmtBuilder.augassign_map: + op_token = StmtBuilder.augassign_map[op] + else: + raise NotSupportedError( + find_before(ctx, rhs.range().start, 'MSG', offsets=(-1, 0)), + "MSG" + op.__name__) + return AugAssign(lhs, op_token, rhs) @staticmethod + def build_While(ctx, stmt): + if stmt.orelse: + raise NotSupportedError(None, "MSG") + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + return While(r, build_expr(ctx, stmt.test), + build_stmts(ctx, stmt.body)) @staticmethod + def build_For(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + return For( + r, [build_expr(ctx, stmt.target)], + [build_expr(ctx, stmt.iter)], build_stmts(ctx, stmt.body)) @staticmethod + def build_If(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + return If(r, build_expr(ctx, stmt.test), + build_stmts(ctx, stmt.body), + build_stmts(ctx, stmt.orelse)) @staticmethod + def build_Print(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + if stmt.dest: + raise NotSupportedError(r, "MSG") + args = [build_expr(ctx, val) for val in stmt.values] + return ExprStmt(Apply(Var(Ident(r, "MSG")), args, [])) @staticmethod + def build_Pass(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + return Pass(r) @staticmethod + def build_Break(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + return Break(r) @staticmethod + def build_Continue(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + return Continue(r) @staticmethod + def build_With(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("MSG")) + return With(r, build_withitems(ctx, stmt.items), build_stmts(ctx, stmt.body)) class ExprBuilder(Builder): + binop_map = { + ast.Add: 'MSG', + ast.Sub: 'MSG', + ast.Mult: 'MSG', + ast.Div: 'MSG', + ast.Pow: 'MSG', + ast.Mod: 'MSG', + ast.FloorDiv: 'MSG', + ast.BitAnd: 'MSG', + ast.BitXor: 'MSG', + ast.BitOr: 'MSG', + ast.LShift: 'MSG', + ast.RShift: 'MSG', + } binop_map[ast.MatMult] = 'MSG' unop_map = { + ast.Not: 'MSG', + ast.USub: 'MSG', + ast.Invert: 'MSG', + } boolop_map = { + ast.And: 'MSG', + ast.Or: 'MSG', + } cmpop_map = { + ast.Eq: 'MSG', + ast.NotEq: 'MSG', + ast.LtE: 'MSG', + ast.Lt: 'MSG', + ast.GtE: 'MSG', + ast.Gt: 'MSG', + ast.Is: 'MSG', + ast.IsNot: 'MSG', + ast.In: 'MSG', + ast.NotIn: 'MSG', + } @staticmethod + def build_Attribute(ctx, expr): + base = build_expr(ctx, expr.value) + source = ctx.source.encode('MSG') def get_char(index): + return chr(source[index]) start_pos = base.range().end + 1 + while get_char(start_pos) in string.whitespace: + start_pos += 1 + end_pos = start_pos + len(expr.attr) + name_range = ctx.make_raw_range(start_pos, end_pos) + return Select(base, Ident(name_range, expr.attr)) @staticmethod + def build_Call(ctx, expr): + func = build_expr(ctx, expr.func) + args = [build_expr(ctx, py_arg) for py_arg in expr.args] + if hasattr(expr, 'MSG') and expr.starargs: + stararg_expr = build_expr(ctx, expr.starargs) + args += [Starred(stararg_expr.range(), stararg_expr)] + kwargs = [] + for kw in expr.keywords: + kw_expr = build_expr(ctx, kw.value) + if not kw.arg: + raise NotSupportedError(kw_expr.range(), 'MSG') + kwargs.append(Attribute(Ident(kw_expr.range(), kw.arg), kw_expr)) + return Apply(func, args, kwargs) @staticmethod + def build_Ellipsis(ctx, expr): + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 3) + return Dots(r) @staticmethod + def build_Name(ctx, expr): + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id)) + if expr.id.startswith(_reserved_prefix): + raise NotSupportedError(r, "MSG" + "MSG" + _reserved_prefix) + if expr.id == "MSG": + return TrueLiteral(r) + elif expr.id == "MSG": + return FalseLiteral(r) + elif expr.id == "MSG": + return NoneLiteral(r) + return Var(Ident(r, expr.id)) @staticmethod + def build_NameConstant(ctx, expr): + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(str(expr.value))) + if expr.value is True: + return TrueLiteral(r) + elif expr.value is False: + return FalseLiteral(r) + elif expr.value is None: + return NoneLiteral(r) + else: + raise ValueError("MSG" + str(expr.value)) @staticmethod + def build_BinOp(ctx, expr): + lhs = build_expr(ctx, expr.left) + rhs = build_expr(ctx, expr.right) + op = type(expr.op) if op == ast.Div and not ctx.uses_true_division: + err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start) + raise FrontendError(err_range, 'MSG' + 'MSG' + 'MSG') + op_token = ExprBuilder.binop_map.get(op) + if op_token is None: + err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start) + raise NotSupportedError(err_range, "MSG" + op.__name__) + return BinOp(op_token, lhs, rhs) @staticmethod + def build_UnaryOp(ctx, expr): + sub_expr = build_expr(ctx, expr.operand) + op = type(expr.op) + op_token = ExprBuilder.unop_map.get(op) + if op_token is None: + raise NotSupportedError(expr.range(), "MSG" + op.__name__) + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(op_token)) + return UnaryOp(r, op_token, sub_expr) @staticmethod + def build_BoolOp(ctx, expr): + if len(expr.values) < 2: + raise AssertionError("MSG" + str(len(expr.values))) + sub_exprs = [build_expr(ctx, sub_expr) for sub_expr in expr.values] + op = type(expr.op) + op_token = ExprBuilder.boolop_map.get(op) + if op_token is None: + err_range = ctx.make_raw_range(sub_exprs[0].range().end, sub_exprs[1].range().start) + raise NotSupportedError(err_range, "MSG" + op.__name__) + lhs = sub_exprs[0] + for rhs in sub_exprs[1:]: + lhs = BinOp(op_token, lhs, rhs) + return lhs @staticmethod + def build_IfExp(ctx, expr): + return TernaryIf(build_expr(ctx, expr.test), + build_expr(ctx, expr.body), + build_expr(ctx, expr.orelse)) @staticmethod + def build_Compare(ctx, expr): + operands = [build_expr(ctx, e) for e in [expr.left] + list(expr.comparators)] + result = None + for lhs, op_, rhs in zip(operands, expr.ops, operands[1:]): + op = type(op_) + op_token = ExprBuilder.cmpop_map.get(op) + r = ctx.make_raw_range(lhs.range().end, rhs.range().start) + if op_token is None: + raise NotSupportedError(r, "MSG" + op.__name__) if op == ast.NotIn: + in_expr = BinOp('MSG', lhs, rhs) + cmp_expr = UnaryOp(r, 'MSG', in_expr) + else: + cmp_expr = BinOp(op_token, lhs, rhs) if result is None: + result = cmp_expr + else: + result = BinOp('MSG', result, cmp_expr) + return result @staticmethod + def build_Subscript(ctx, expr): + def build_SliceExpr(ctx, base, slice_expr): + lower = build_expr(ctx, slice_expr.lower) if slice_expr.lower is not None else None + upper = build_expr(ctx, slice_expr.upper) if slice_expr.upper is not None else None + step = build_expr(ctx, slice_expr.step) if slice_expr.step is not None else None + return SliceExpr(base.range(), lower, upper, step) def build_Index(ctx, base, index_expr): + if isinstance(index_expr.value, ast.Tuple) or \ + isinstance(index_expr.value, ast.List): + raise NotSupportedError(base.range(), + "MSG" + "MSG") + return build_expr(ctx, index_expr.value) def build_ExtSlice(ctx, base, extslice): + sub_exprs = [] + for expr in extslice.dims: + sub_type = type(expr) + if sub_type is ast.Index: + sub_exprs.append(build_Index(ctx, base, expr)) + elif sub_type is ast.Slice: + sub_exprs.append(build_SliceExpr(ctx, base, expr)) + elif sub_type is ast.Ellipsis: + sub_exprs.append(Dots(base.range())) + else: + raise NotSupportedError(base.range(), + "MSG" + "MSG".format(sub_type)) + return sub_exprs base = build_expr(ctx, expr.value) + sub_type = type(expr.slice) + if sub_type is ast.Index: + if isinstance(expr.slice.value, ast.Tuple): + indices = [build_expr(ctx, index_expr) for index_expr in expr.slice.value.elts] + return Subscript(base, indices) + else: + return Subscript(base, [build_expr(ctx, expr.slice.value)]) + elif sub_type is ast.Slice: + return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)]) + elif sub_type is ast.ExtSlice: + return Subscript(base, build_ExtSlice(ctx, base, expr.slice)) + elif sys.version_info >= (3, 9): + if sub_type is ast.Tuple: + indices = [] + for index_expr in expr.slice.elts: + if isinstance(index_expr, ast.Slice): + indices.append(build_SliceExpr(ctx, base, index_expr)) + else: + indices.append(build_expr(ctx, index_expr)) + return Subscript(base, indices) + return Subscript(base, [build_expr(ctx, expr.slice)]) + else: + raise NotSupportedError(base.range(), "MSG") @staticmethod + def build_List(ctx, expr): + return ListLiteral(ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1), + [build_expr(ctx, e) for e in expr.elts]) @staticmethod + def build_Tuple(ctx, expr): + return TupleLiteral(ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1), + [build_expr(ctx, e) for e in expr.elts]) @staticmethod + def build_Dict(ctx, expr): + return DictLiteral(ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1), + [build_expr(ctx, e) for e in expr.keys], [build_expr(ctx, e) for e in expr.values]) @staticmethod + def build_Num(ctx, expr): + value = str(expr.n) + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value)) + return Const(r, value) @staticmethod + def build_Constant(ctx, expr): + value = expr.value + if value is None or isinstance(value, bool): + return ExprBuilder.build_NameConstant(ctx, expr) + if isinstance(value, (int, float)): + return ExprBuilder.build_Num(ctx, expr) + elif isinstance(value, str): + return ExprBuilder.build_Str(ctx, expr) + elif isinstance(value, type(Ellipsis)): + return ExprBuilder.build_Ellipsis(ctx, expr) + else: + error_range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(str(value))) + raise FrontendError(error_range, "MSG") @staticmethod + def build_Str(ctx, expr): + value = str(expr.s) + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) + return StringLiteral(r, value) @staticmethod + def build_JoinedStr(ctx, expr): + s = 'MSG' + args = [] + for value in expr.values: + r = ctx.make_range(value.lineno, value.col_offset, value.col_offset + 1) + if isinstance(value, ast.FormattedValue): + if value.conversion != -1: + raise NotSupportedError(r, 'MSG't support conversion in JoinedStr'MSG'Don\'MSG') + s += 'MSG' + args.append(build_expr(ctx, value.value)) + elif isinstance(value, ast.Str): + s += value.s + else: + raise NotSupportedError(r, 'MSG') r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) + return Apply(Select(StringLiteral(r, s), Ident(r, 'MSG')), args, []) @staticmethod + def build_ListComp(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset) + if (len(stmt.generators) > 1): + raise NotSupportedError(r, "MSG") if (len(stmt.generators[0].ifs) != 0): + raise NotSupportedError(r, "MSG") elt_expr = build_expr(ctx, stmt.elt) + target_expr = build_expr(ctx, stmt.generators[0].target) iter_expr = build_expr(ctx, stmt.generators[0].iter) + return ListComp(r, elt_expr, target_expr, iter_expr) @staticmethod + def build_Starred(ctx, expr): + r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) + return Starred(r, build_expr(ctx, expr.value)) build_expr = ExprBuilder() +build_stmt = StmtBuilder() +build_withitem = WithItemBuilder() def find_before(ctx, pos, substr, offsets=(0, 0)): + new_pos = ctx.source[:pos].rindex(substr) + return ctx.make_raw_range(new_pos + offsets[0], new_pos + len(substr) + offsets[1]) +import torch from typing import Tuple, Optional, List from torch import Tensor, _VF from torch.nn.utils.rnn import PackedSequence import warnings class QuantizedLinear(torch.jit.ScriptModule): + __constants__ = ['MSG', 'MSG'] def __init__(self, other): + super(QuantizedLinear, self).__init__() + self.in_features = other.in_features + self.out_features = other.out_features + self.weight, self.col_offsets, self.scale, self.zero_point = torch.fbgemm_linear_quantize_weight( + other.weight.clone(memory_format=torch.contiguous_format).float()) + self.weight = torch.nn.Parameter(self.weight, requires_grad=False) + self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False) + assert other.bias is not None, 'MSG' + self.bias = torch.nn.Parameter(other.bias.clone(memory_format=torch.contiguous_format).float(), requires_grad=False) self.register_buffer( + 'MSG', + torch.fbgemm_pack_quantized_matrix(self.weight.clone(memory_format=torch.contiguous_format))) @torch.jit.script_method + def _unpack(self): + self.packed_tensor_ptr.set_( + torch.fbgemm_pack_quantized_matrix(self.weight)) @torch.jit.script_method + def _pack(self): + self.packed_tensor_ptr.set_( + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()) @torch.jit.script_method + def forward(self, input): + out = torch.fbgemm_linear_int8_weight_fp32_activation( + input.float(), self.weight, self.packed_tensor_ptr, self.col_offsets, + self.scale, self.zero_point, self.bias) + return out.to(input.dtype) def extra_repr(self): + repr = 'MSG' \ + 'MSG'.format(**self.__dict__) + return repr +class QuantizedLinearFP16(torch.jit.ScriptModule): def __init__(self, other): + super(QuantizedLinearFP16, self).__init__() + self.in_features = other.in_features + self.out_features = other.out_features + self.original_weight = other.weight + self.weight = torch.fbgemm_pack_gemm_matrix_fp16( + other.weight.clone(memory_format=torch.contiguous_format).float()) + assert other.bias is not None, 'MSG' + self.bias = torch.nn.Parameter(other.bias.clone(memory_format=torch.contiguous_format).float(), requires_grad=False) + self.register_buffer('MSG', self.weight) @torch.jit.script_method + def _unpack(self): + self.packed_weight.set_( + torch.fbgemm_pack_gemm_matrix_fp16( + self.original_weight)) @torch.jit.script_method + def _pack(self): + self.packed_weight.set_( + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()) @torch.jit.script_method + def forward(self, input): + out = torch.fbgemm_linear_fp16_weight_fp32_activation( + input.float(), self.packed_weight, self.bias) + return out def extra_repr(self): + repr = 'MSG'.format(**self.__dict__) + return repr +class QuantizedRNNCellBase(torch.jit.ScriptModule): + __constants__ = ['MSG', 'MSG', 'MSG', 'MSG', 'MSG', + 'MSG', 'MSG'] def __init__(self, other): + super(QuantizedRNNCellBase, self).__init__() + self.input_size = other.input_size + self.hidden_size = other.hidden_size + self.bias = other.bias + if not self.bias: + raise ValueError("MSG") weight_ih, col_offsets_ih, self.scale_ih, self.zero_point_ih = \ + torch.fbgemm_linear_quantize_weight(other.weight_ih.clone(memory_format=torch.contiguous_format).float()) + self.register_buffer('MSG', weight_ih) + self.register_buffer('MSG', col_offsets_ih) + weight_hh, col_offsets_hh, self.scale_hh, self.zero_point_hh = \ + torch.fbgemm_linear_quantize_weight(other.weight_hh.clone(memory_format=torch.contiguous_format).float()) + self.register_buffer('MSG', weight_hh) + self.register_buffer('MSG', col_offsets_hh) packed_ih = torch.fbgemm_pack_quantized_matrix(self.weight_ih) + self.register_buffer('MSG', packed_ih) + packed_hh = torch.fbgemm_pack_quantized_matrix(self.weight_hh) + self.register_buffer('MSG', packed_hh) self.bias_ih = torch.nn.Parameter(other.bias_ih.clone(memory_format=torch.contiguous_format).float(), requires_grad=False) + self.bias_hh = torch.nn.Parameter(other.bias_hh.clone(memory_format=torch.contiguous_format).float(), requires_grad=False) def extra_repr(self): + s = 'MSG' + if 'MSG' in self.__dict__ and self.bias is not True: + s += 'MSG' + if 'MSG' in self.__dict__ and self.nonlinearity != "MSG": + s += 'MSG' + return s.format(**self.__dict__) @torch.jit.script_method + def check_forward_input(self, input): + if input.size(1) != self.input_size: + raise RuntimeError( + "MSG".format( + input.size(1), self.input_size)) @torch.jit.script_method + def check_forward_hidden(self, input, hx, hidden_label='MSG'): + if input.size(0) != hx.size(0): + raise RuntimeError( + "MSG".format( + input.size(0), hidden_label, hx.size(0))) if hx.size(1) != self.hidden_size: + raise RuntimeError( + "MSG".format( + hidden_label, hx.size(1), self.hidden_size)) + + + + @torch.jit.script_method + def _unpack(self): + self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(self.weight_ih)) + self.packed_hh.set_(torch.fbgemm_pack_quantized_matrix(self.weight_hh)) + @torch.jit.script_method + def _pack(self): + self.packed_ih.set_( + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()) + self.packed_hh.set_( + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()) +class QuantizedRNNCell(QuantizedRNNCellBase): + __constants__ = ['MSG', 'MSG', 'MSG', 'MSG', 'MSG', + 'MSG', 'MSG', 'MSG'] def __init__(self, other): + super(QuantizedRNNCell, self).__init__(other) + self.nonlinearity = other.nonlinearity @torch.jit.script_method + def forward(self, input, hx=None): + self.check_forward_input(input) + if hx is None: + hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + self.check_forward_hidden(input, hx, 'MSG') + if self.nonlinearity == "MSG": + ret = _VF.quantized_rnn_tanh_cell( + input, hx, self.weight_ih, self.weight_hh, self.bias_ih, + self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih, + self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih, + self.zero_point_hh + ) + elif self.nonlinearity == "MSG": + ret = _VF.quantized_rnn_relu_cell( + input, hx, self.weight_ih, self.weight_hh, self.bias_ih, + self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih, + self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih, + self.zero_point_hh + ) + else: + ret = input + raise RuntimeError( + "MSG".format(self.nonlinearity)) + return ret +class QuantizedLSTMCell(QuantizedRNNCellBase): + def __init__(self, other): + super(QuantizedLSTMCell, self).__init__(other) @torch.jit.script_method + def forward(self, input, hx=None): + self.check_forward_input(input) + if hx is None: + zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + hx = (zeros, zeros) + self.check_forward_hidden(input, hx[0], 'MSG') + self.check_forward_hidden(input, hx[1], 'MSG') + return _VF.quantized_lstm_cell( + input, hx, self.weight_ih, self.weight_hh, self.bias_ih, + self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih, + self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih, + self.zero_point_hh + ) +class QuantizedGRUCell(QuantizedRNNCellBase): + def __init__(self, other): + super(QuantizedGRUCell, self).__init__(other) @torch.jit.script_method + def forward(self, input, hx=None): + self.check_forward_input(input) + if hx is None: + hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + self.check_forward_hidden(input, hx, 'MSG') + return _VF.quantized_gru_cell( + input, hx, self.weight_ih, self.weight_hh, self.bias_ih, + self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih, + self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih, + self.zero_point_hh + ) +def apply_permutation(tensor, permutation, dim=1): + + return tensor.index_select(dim, permutation) +class QuantizedRNNBase(torch.jit.ScriptModule): + __constants__ = ['MSG', 'MSG', 'MSG', 'MSG', 'MSG', + 'MSG', 'MSG', 'MSG', 'MSG'] def __init__(self, other, dtype=torch.int8): + super(QuantizedRNNBase, self).__init__() + self.mode = other.mode + self.input_size = other.input_size + self.hidden_size = other.hidden_size + self.num_layers = other.num_layers + self.bias = other.bias + self.batch_first = other.batch_first + if self.mode != 'MSG': + assert not self.batch_first + self.dropout = other.dropout + self.bidirectional = other.bidirectional + num_directions = 2 if self.bidirectional else 1 + self.dtype = dtype assert self.bias if self.mode != 'MSG' and self.mode != 'MSG': + raise RuntimeError('MSG') if dtype != torch.int8 and dtype != torch.float16: + raise RuntimeError('MSG'.format(dtype)) self.all_weights = [] + for layer in range(self.num_layers): + for direction in range(num_directions): + layer_input_size = self.input_size if layer == 0 else self.hidden_size * num_directions suffix = 'MSG' if direction == 1 else 'MSG' def get_weight_bias(ihhh): + weight_name = 'MSG'.format(ihhh, layer, suffix) + bias_name = 'MSG'.format(ihhh, layer, suffix) weight = getattr(other, weight_name) + bias = getattr(other, bias_name) + return weight, bias weight_ih, bias_ih = get_weight_bias('MSG') + weight_hh, bias_hh = get_weight_bias('MSG') if dtype == torch.int8: + cell_params = torch.ops.quantized.make_quantized_cell_params( + weight_ih, weight_hh, bias_ih, bias_hh) + else: + packed_ih = torch.ops.quantized.linear_prepack_fp16( + weight_ih.float(), bias_ih) + packed_hh = torch.ops.quantized.linear_prepack_fp16( + weight_hh.float(), bias_hh) cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( + packed_ih, packed_hh) setattr(self, 'MSG'.format(layer, suffix), cell_params) + self.all_weights.append(cell_params) @torch.jit.script_method + def check_input(self, input, batch_sizes): + expected_input_dim = 2 if batch_sizes is not None else 3 + if input.dim() != expected_input_dim: + raise RuntimeError( + 'MSG'.format( + expected_input_dim, input.dim())) + if self.input_size != input.size(-1): + raise RuntimeError( + 'MSG'.format( + self.input_size, input.size(-1))) @torch.jit.script_method + def get_expected_hidden_size(self, input, batch_sizes): + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = (self.num_layers * num_directions, + mini_batch, self.hidden_size) + return expected_hidden_size @torch.jit.script_method + def check_hidden_size(self, hx, expected_hidden_size, msg='MSG'): + if hx.size() != expected_hidden_size: + raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) @torch.jit.script_method + def check_forward_args(self, input, hidden, batch_sizes): + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + self.check_hidden_size(hidden, expected_hidden_size, msg='MSG') @torch.jit.script_method + def permute_hidden(self, hx, permutation): + if permutation is None: + return hx + return apply_permutation(hx, permutation) +class QuantizedLSTM(QuantizedRNNBase): + __overloads__ = {'MSG': ['MSG', 'MSG']} def __init__(self, other, dtype): + super(QuantizedLSTM, self).__init__(other, dtype) @torch.jit.script_method + def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): + if hx is None: + num_directions = 2 if self.bidirectional else 1 + zeros = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + hx = (zeros, zeros) + else: + hx = self.permute_hidden(hx, sorted_indices) self.check_forward_args(input, hx, batch_sizes) + assert batch_sizes is None + result = torch.quantized_lstm(input, hx, self.all_weights, self.bias, self.num_layers, + float(self.dropout), self.training, self.bidirectional, + self.batch_first, dtype=self.dtype, use_dynamic=False) + output = result[0] + hidden = result[1:] return output, hidden @torch.jit.script_method + def forward_tensor(self, input, hx=None): + batch_sizes = None + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices) return output, self.permute_hidden(hidden, unsorted_indices) @torch.jit.script_method + def forward_packed(self, input, hx=None): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + max_batch_size = int(max_batch_size) output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices) output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) + return output, self.permute_hidden(hidden, unsorted_indices) + @torch.jit.script_method + def permute_hidden(self, hx, permutation): + if permutation is None: + return hx + return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation) @torch.jit.script_method + def check_forward_args(self, input, hidden, batch_sizes): + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) self.check_hidden_size(hidden[0], expected_hidden_size, + 'MSG') + self.check_hidden_size(hidden[1], expected_hidden_size, + 'MSG') def forward(self, input, hx=None): + if isinstance(input, PackedSequence): + return self.forward_packed(input, hx) + else: + return self.forward_tensor(input, hx) +class QuantizedGRU(QuantizedRNNBase): + __overloads__ = {'MSG': ['MSG', 'MSG']} @torch.jit.script_method + def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + else: + hx = self.permute_hidden(hx, sorted_indices) self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = torch.quantized_gru(input, hx, self.all_weights, self.bias, self.num_layers, + float(self.dropout), self.training, self.bidirectional, + self.batch_first) + else: + result = torch.quantized_gru(input, batch_sizes, hx, self.all_weights, self.bias, self.num_layers, + float(self.dropout), self.training, self.bidirectional) output = result[0] + hidden = result[1] return output, hidden @torch.jit.script_method + def forward_tensor(self, input, hx=None): + batch_sizes = None + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices) + return output, self.permute_hidden(hidden, unsorted_indices) @torch.jit.script_method + def forward_packed(self, input, hx=None): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + max_batch_size = int(max_batch_size) output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices) output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) + return output, self.permute_hidden(hidden, unsorted_indices) def forward(self, input, hx=None): + if isinstance(input, PackedSequence): + return self.forward_packed(input, hx) + else: + return self.forward_tensor(input, hx) +def quantize_rnn_cell_modules(module): + warnings.warn("MSG" + "MSG") + reassign = {} + for name, mod in module.named_modules(): + if mod is module: + continue + new_mod = quantize_rnn_cell_modules(mod) + if new_mod is not mod: + reassign[name] = new_mod + for name, mod in reassign.items(): + setattr(module, name, mod) + if isinstance(module, torch.nn.LSTMCell): + return QuantizedLSTMCell(module) + if isinstance(module, torch.nn.GRUCell): + return QuantizedGRUCell(module) + if isinstance(module, torch.nn.RNNCell): + return QuantizedRNNCell(module) + return module +def quantize_linear_modules(module, dtype=torch.int8): + warnings.warn("MSG" + "MSG") reassign = {} + for name, mod in module.named_modules(): + if mod is module: + continue + new_mod = quantize_linear_modules(mod, dtype) + if new_mod is not mod: + reassign[name] = new_mod for name, mod in reassign.items(): + setattr(module, name, mod) + if isinstance(module, torch.nn.Linear): + if dtype == torch.int8: + return QuantizedLinear(module) + elif dtype == torch.float16: + return QuantizedLinearFP16(module) + else: + raise RuntimeError( + "MSG".format(dtype)) + return module +def quantize_rnn_modules(module, dtype=torch.int8): + warnings.warn("MSG" + "MSG") + reassign = {} + for name, mod in module.named_modules(): + if mod is module: + continue + new_mod = quantize_rnn_modules(mod, dtype) + if new_mod is not mod: + reassign[name] = new_mod for name, mod in reassign.items(): + setattr(module, name, mod) + if isinstance(module, torch.nn.LSTM): + if dtype != torch.int8 and dtype != torch.float16: + raise RuntimeError("MSG".format(dtype)) + return QuantizedLSTM(module, dtype) + if isinstance(module, torch.nn.GRU): + return QuantizedGRU(module) + return module +import torch.jit +from torch.jit._builtins import _find_builtin +import inspect +import textwrap +def _hidden(name): + return name.startswith('MSG') and not name.startswith('MSG') def _emit_type(type): + return str(type) def _emit_arg(indent, i, arg): + v = "MSG".format(arg.name, _emit_type(arg.type)) + default = arg.default_value + if default is not None: + v = "MSG".format(v, str(default)) + if i > 0: + v = "MSG".format("MSG" * indent, v) + return v def _emit_args(indent, arguments): + return "MSG".join(_emit_arg(indent, i, arg) for i, arg in enumerate(arguments)) def _emit_ret(ret): + return _emit_type(ret.type) def _emit_rets(returns): + if len(returns) == 1: + return _emit_ret(returns[0]) + return "MSG".format("MSG".join(_emit_ret(r) for r in returns)) def _emit_schema(mod, name, schema, arg_start=0, padding=4): + if mod is None: + qualified_name = name + else: + qualified_name = "MSG".format(mod, name) + schema_str = "MSG".format(qualified_name, + _emit_args(len(qualified_name) + 1 + padding, schema.arguments[arg_start:]), + _emit_rets(schema.returns)) + return schema_str def _get_tensor_ops(): + def is_tensor_method(schema): + if len(schema.arguments) == 0: + return False + self = schema.arguments[0] + if self.name != 'MSG': + return False + if not self.type.isSubtypeOf(torch._C.TensorType.get()): + return False + return True methods = [] + + for elem in dir(torch.Tensor): + if not _hidden(elem): + schemas = torch._C._jit_get_schemas_for_operator("MSG" + elem) + for schema in schemas: + if is_tensor_method(schema): + methods.append(_emit_schema('MSG', elem, schema, arg_start=1)) return "MSG", methods def _get_nn_functional_ops(): + functions = [] + mod = torch.nn.functional + name = mod.__name__ + for elem in dir(torch.nn.functional): + attr = getattr(mod, elem) + if not inspect.isfunction(attr) or _hidden(elem[0]): + continue attr_module = inspect.getmodule(attr) + if not attr_module: + raise RuntimeError(f'MSG') if 'MSG' not in attr_module.__name__: + continue try: + scripted = torch.jit.script(attr) + schema = scripted.schema + functions.append(_emit_schema(name, elem, schema)) + except: + pass + for mod in torch.jit._builtins._modules_containing_builtins: + name = mod.__name__ + for elem in dir(mod): + builtin = _find_builtin(getattr(mod, elem)) + if builtin is not None: + schemas = torch._C._jit_get_schemas_for_operator(builtin) + for schema in schemas: + if not _hidden(elem): + functions.append(_emit_schema(name, elem, schema)) + return "MSG", functions def _get_builtins_helper(): + builtins = [] + for fn, _builtin_name in torch.jit._builtins._builtin_ops: + mod = inspect.getmodule(fn) if not hasattr(fn, 'MSG'): + continue + if not mod: + continue + if _hidden(fn.__name__) or _hidden(fn.__qualname__) or _hidden(mod.__name__): + continue if 'MSG' in mod.__name__: + continue builtins.append((fn, _builtin_name)) return builtins def _is_math_fn(fn): + mod = inspect.getmodule(fn) + if not mod: + raise RuntimeError(f'MSG') return mod.__name__ == 'MSG' +def _get_torchscript_builtins(): + functions = [] + builtins = filter(lambda fn: not _is_math_fn(fn[0]), _get_builtins_helper()) + builtins_list = list(builtins) + + for fn, _builtin_name in builtins_list: + mod = inspect.getmodule(fn) + if not mod: + raise RuntimeError(f'MSG') + builtin = _find_builtin(fn) + if builtin is not None: + schemas = torch._C._jit_get_schemas_for_operator(builtin) + for schema in schemas: + functions.append(_emit_schema(mod.__name__, fn.__name__, schema)) + pass return "MSG", functions +def _get_math_builtins(): + functions = [] + builtins = filter(lambda fn: _is_math_fn(fn[0]), _get_builtins_helper()) + builtins_list = list(builtins) + + for fn, _builtin_name in builtins_list: + mod = inspect.getmodule(fn) + if not mod: + raise RuntimeError(f'MSG') + builtin = _find_builtin(fn) + if builtin is not None: + schemas = torch._C._jit_get_schemas_for_operator(builtin) + for schema in schemas: + schema_str = _emit_schema(mod.__name__, fn.__name__, schema) + if 'MSG' in schema_str: + continue + functions.append(schema) + pass return "MSG", functions +def _get_global_builtins(): + + supported_builtins = [ + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + ] op_renames = { + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + } schemaless_op_explanations = { + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + 'MSG': 'MSG', + } magic_methods = [ + ('MSG', 'MSG'), + ('MSG', 'MSG'), + ('MSG', 'MSG'), + ('MSG', 'MSG'), + ('MSG', 'MSG'), + ('MSG', 'MSG'), + ('MSG', 'MSG'), + ] magic_methods_rows = [] + for fn, magic_method in magic_methods: + magic_methods_rows.append('MSG'.format(fn, magic_method)) schematized_ops = [] + schemaless_ops = [] for fn in supported_builtins: + op_name = 'MSG'.format(fn) + if fn in op_renames: + op_name = op_renames[fn] + schemas = torch._C._jit_get_schemas_for_operator(op_name) + for s in schemas: + schematized_ops.append(_emit_schema(None, fn, s, padding=0)) + if len(schemas) > 0: + schematized_ops.append('MSG') + else: + table_row = 'MSG'.format(fn, schemaless_op_explanations[fn]) + schemaless_ops.append(table_row) schematized_ops_str = 'MSG'.join(schematized_ops) + schemaless_ops_str = 'MSG'.join(schemaless_ops) + magic_methods_rows_str = 'MSG'.join(magic_methods_rows) + schematized_ops_str = textwrap.indent(schematized_ops_str, 'MSG') + schemaless_ops_str = textwrap.indent(schemaless_ops_str, 'MSG') + magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, 'MSG') + section = .format(schemaless_ops_str, magic_methods_rows_str, schematized_ops_str) return "MSG", section +def _list_supported_ops(): + def emit_block(decls): + return 'MSG'.format('MSG'.join('MSG'.format(d) for d in decls)) body = 'MSG' + op_gathering_fns = ( + _get_tensor_ops, + _get_nn_functional_ops, + _get_torchscript_builtins, + _get_global_builtins, + _get_math_builtins, + ) + for fn in op_gathering_fns: + header, items = fn() + link_target = header.replace('MSG', 'MSG').replace('MSG', 'MSG').lower().replace('MSG', 'MSG') + if isinstance(items, str): + section = "MSG".format(header, 'MSG' * len(header), items) + else: + section = "MSG".format(header, 'MSG' * len(header), emit_block(items)) + section = 'MSG'.format(link_target) + 'MSG' + section + body += section return body __doc__ = _list_supported_ops() +import torch.jit +from textwrap import dedent from typing import Dict, Any def execWrapper(code, glob, loc): + exec(code, glob, loc) def _gen_unsupported_methods_properties(): + tensor_attrs = set(filter(lambda x: x[0] != "MSG", dir(torch.Tensor))) + tensor = torch.tensor([2]) + funcs_template = dedent() deprecated_apis = set(["MSG", "MSG", "MSG", "MSG", "MSG", "MSG", "MSG", "MSG", "MSG"]) + tensor_attrs = tensor_attrs - deprecated_apis properties = [] + methods = [] + sorted_tensor_attrs = sorted(list(tensor_attrs), key=lambda x: x.lower()) + for attr in sorted_tensor_attrs: + funcs_str = funcs_template.format(op=attr) + scope: Dict[str, Any] = {} + execWrapper(funcs_str, globals(), scope) + try: + cu = torch.jit.CompilationUnit(funcs_str) + except Exception as e: + if "MSG" not in repr(e): + continue + attr_repr = repr(getattr(tensor, attr)) + if "MSG" in attr_repr or "MSG" in attr_repr: + methods.append(attr) + else: + properties.append(attr) mapped_methods = map(lambda x: "MSG" + x + r"MSG", methods) + mapped_properties = map(lambda x: "MSG" + x + r"MSG", properties) + return "MSG".join(mapped_methods), "MSG".join(mapped_properties) +def _list_unsupported_tensor_ops(): + header = + methods, properties = _gen_unsupported_methods_properties() + return header + "MSG" + methods + + "MSG" + properties __doc__ = _list_unsupported_tensor_ops() +import torch from torch.utils import set_module +from torch.jit._builtins import _register_builtin +from torch._jit_internal import Future set_module(Future, "MSG") +def fork(func, *args, **kwargs): + + return torch._C.fork(func, *args, **kwargs) +def wait(future): + + return torch._C.wait(future) +_register_builtin(wait, "MSG") +import math +import warnings import torch +import torch.backends.cudnn as cudnn from torch._six import PY37 +from ..nn.modules.utils import _single, _pair, _triple, _quadruple, _list_with_default from collections import OrderedDict +from typing import Dict, Optional _builtin_table: Optional[Dict[int, str]] = None _modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg) _builtin_ops = [ + + (_pair, "MSG"), + (_quadruple, "MSG"), + (_single, "MSG"), + (_triple, "MSG"), + (_list_with_default, "MSG"), + (OrderedDict, "MSG"), + (dict, "MSG"), + (cudnn.is_acceptable, "MSG"), + (math.ceil, "MSG"), + (math.copysign, "MSG"), + (math.erf, "MSG"), + (math.erfc, "MSG"), + (math.exp, "MSG"), + (math.expm1, "MSG"), + (math.fabs, "MSG"), + (math.floor, "MSG"), + (math.gamma, "MSG"), + (math.lgamma, "MSG"), + (math.log, "MSG"), + (math.log10, "MSG"), + (math.log1p, "MSG"), + (math.pow, "MSG"), + (math.sqrt, "MSG"), + (math.isnan, "MSG"), + (math.asinh, "MSG"), + (math.atanh, "MSG"), + (math.cosh, "MSG"), + (math.sinh, "MSG"), + (math.tanh, "MSG"), + (math.acos, "MSG"), + (math.asin, "MSG"), + (math.atan, "MSG"), + (math.atan2, "MSG"), + (math.cos, "MSG"), + (math.sin, "MSG"), + (math.tan, "MSG"), + (math.asinh, "MSG"), + (math.atanh, "MSG"), + (math.acosh, "MSG"), + (math.sinh, "MSG"), + (math.cosh, "MSG"), + (math.tanh, "MSG"), + (math.fmod, "MSG"), + (math.modf, "MSG"), + (math.factorial, "MSG"), + (math.frexp, "MSG"), + (math.isnan, "MSG"), + (math.isinf, "MSG"), + (math.degrees, "MSG"), + (math.radians, "MSG"), + (math.ldexp, "MSG"), + (torch.autograd.grad, "MSG"), + (torch.autograd.backward, "MSG"), + (torch._C._infer_size, "MSG"), + (torch.nn.functional._no_grad_embedding_renorm_, "MSG"), + (torch.nn.functional.assert_int_or_pair, "MSG"), + (torch.nn.init._no_grad_fill_, "MSG"), + (torch.nn.init._no_grad_normal_, "MSG"), + (torch.nn.init._no_grad_uniform_, "MSG"), + (torch.nn.init._no_grad_zero_, "MSG"), + (torch._C._get_tracing_state, "MSG"), + (warnings.warn, "MSG"), + (torch._VF.stft, "MSG"), + (torch._VF.istft, "MSG"), + (torch._VF.cdist, "MSG"), + (torch._VF.norm, "MSG"), + (torch._VF.unique_dim, "MSG"), + (torch._VF.unique_consecutive, "MSG"), + (torch._VF.nuclear_norm, "MSG"), + (torch._VF.frobenius_norm, "MSG"), +] def _gen_torch_functional_registered_ops(): + + + + + ops = ["MSG", "MSG", "MSG", "MSG", "MSG", "MSG", "MSG", "MSG"] + return set(getattr(torch.functional, name) for name in ops) _functional_registered_ops = _gen_torch_functional_registered_ops() def _is_special_functional_bound_op(fn): + return fn in _functional_registered_ops +def _get_builtin_table(): + global _builtin_table + if _builtin_table is not None: + return _builtin_table + _builtin_table = {} def register_all(mod): + for name in dir(mod): + v = getattr(mod, name) + if callable(v) and not _is_special_functional_bound_op(v) and v is not torch.no_grad: + _builtin_ops.append((v, "MSG" + name)) + for mod in _modules_containing_builtins: + register_all(mod) _builtin_ops.append((math.gcd, "MSG")) + _builtin_ops.append((math.isfinite, "MSG")) + if PY37: + _builtin_ops.append((math.remainder, "MSG")) import torch.distributed.autograd as dist_autograd + if dist_autograd.is_available(): + _builtin_ops.append((dist_autograd.get_gradients, "MSG")) + _builtin_ops.append((dist_autograd.backward, "MSG")) + for builtin, aten_op in _builtin_ops: + _builtin_table[id(builtin)] = aten_op return _builtin_table +def _register_builtin(fn, op): + _get_builtin_table()[id(fn)] = op +def _find_builtin(fn): + return _get_builtin_table().get(id(fn)) +from typing import Optional, List import torch +from torch.jit._script import RecursiveScriptModule, ScriptModule +def freeze(mod, preserved_attrs: Optional[List[str]] = None): + r + if not isinstance(mod, ScriptModule): + raise RuntimeError( + "MSG" + "MSG" + ) if mod.training: + raise RuntimeError( + "MSG" + "MSG" + ) preserved_attrs = preserved_attrs if preserved_attrs is not None else [] out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs)) + RecursiveScriptModule._finalize_scriptmodule(out) return out +import contextlib import torch @contextlib.contextmanager +def optimized_execution(should_optimize): + + stored_flag = torch._C._get_graph_executor_optimize() + torch._C._set_graph_executor_optimize(should_optimize) + try: + yield + finally: + torch._C._set_graph_executor_optimize(stored_flag) @contextlib.contextmanager +def fuser(name): + + old_cpu_fuse = torch._C._jit_can_fuse_on_cpu() + old_gpu_fuse = torch._C._jit_can_fuse_on_gpu() + old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() + old_nvfuser_state = torch._C._jit_nvfuser_enabled() + if name == 'MSG': + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) + elif name == 'MSG': + old_profiling_executor = torch._C._jit_set_profiling_executor(True) + old_profiling_mode = torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(True) + torch._C._jit_set_texpr_fuser_enabled(True) + torch._C._jit_set_nvfuser_enabled(False) + elif name == 'MSG': + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + else: + raise Exception("MSG") + try: + yield + finally: + if name == 'MSG': + torch._C._jit_set_profiling_executor(old_profiling_executor) + torch._C._jit_set_profiling_mode(old_profiling_mode) + torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse) + torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse) + torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state) + torch._C._jit_set_nvfuser_enabled(old_nvfuser_state) +last_executed_optimized_graph = torch._C._last_executed_optimized_graph +def _graph_for(self, *args, **kwargs): + self(*args, **kwargs) + return last_executed_optimized_graph() +import torch add_stat_value = torch.ops.prim.AddStatValue set_logger = torch._C._logging_set_logger +LockingLogger = torch._C.LockingLogger +AggregationType = torch._C.AggregationType +NoopLogger = torch._C.NoopLogger time_point = torch.ops.prim.TimePoint def build_intlist(data): + return data +def build_tensorlist(data): + return data +def build_doublelist(data): + return data +def build_boollist(data): + return data +def build_tensor_from_id(data): + if isinstance(data, int): + return data +def restore_type_tag(value, type_str): + + + + return value +import inspect +import torch +import collections +import textwrap +import functools +import warnings +from typing import Dict, List, Set, Type import torch._jit_internal as _jit_internal +from torch.jit.frontend import get_default_args, get_jit_def, get_class_properties +from torch.jit._builtins import _find_builtin +from torch.nn import Module +from torch._six import get_function_from_type, bind_method +ScriptMethodStub = collections.namedtuple('MSG', ('MSG', 'MSG', 'MSG')) +PropertyStub = collections.namedtuple('MSG', ('MSG', 'MSG')) ignored_attributes = [ + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", +] def make_stub(func, name): + rcb = _jit_internal.createResolutionCallbackFromClosure(func) + ast = get_jit_def(func, name, self_name="MSG") + return ScriptMethodStub(rcb, ast, func) def make_stub_from_method(nn_module, method_name): + func = getattr(nn_module, method_name) + if isinstance(func, ScriptMethodStub): + return func + + + + + + + + + return make_stub(func, method_name) +def make_stubs_from_exported_methods(mod): + stubs = [] + for name in dir(mod): + item = getattr(mod, name, None) + if ( + _jit_internal.get_torchscript_modifier(item) + is _jit_internal.FunctionModifiers.EXPORT + ): + stubs.append(make_stub_from_method(mod, name)) return stubs +_constant_types = (bool, float, int, str, type(None), torch.device, torch.layout, torch.dtype) def _get_valid_constant(attr, v, owner_type): + if isinstance(v, _constant_types): + return v + elif isinstance(v, tuple) or isinstance(v, list): + return tuple(_get_valid_constant(attr, x, owner_type) for x in v) + constants = "MSG".join(torch.typename(typ) for typ in _constant_types) + raise TypeError(textwrap.dedent(.format(torch.typename(type(v)), owner_type, attr, constants))) +class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): + def __init__(self, source, filename, file_lineno, leading_whitespace_len): + super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len) +def infer_concrete_type_builder(nn_module, share_types=True): + + concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module)) + if isinstance(nn_module, (torch.nn.ModuleDict)): + concrete_type_builder.set_module_dict() + if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)): + concrete_type_builder.set_module_list() class_annotations = getattr(nn_module, 'MSG', {}) + def infer_type(name, item): + if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["MSG"]: + attr_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range()) + elif isinstance(item, torch.jit.Attribute): + attr_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range()) + else: + attr_type = torch._C._jit_try_infer_type(item) + return attr_type added_names = set() for name, item in nn_module._parameters.items(): + assert item is None or isinstance(item, torch.Tensor) + attr_type = infer_type(name, item) + concrete_type_builder.add_attribute(name, attr_type, True, False) + added_names.add(name) for name, item in nn_module._buffers.items(): + assert item is None or isinstance(item, torch.Tensor) + attr_type = infer_type(name, item) + concrete_type_builder.add_attribute(name, attr_type, False, True) + added_names.add(name) for name, item in nn_module._modules.items(): + attr_type = infer_type(name, item) + if item is None: + concrete_type_builder.add_attribute(name, attr_type, False, False) + continue + if attr_type is not None: + assert attr_type.is_interface_type() + sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type) + else: + sub_concrete_type = get_module_concrete_type(item, share_types) + concrete_type_builder.add_module(name, sub_concrete_type) added_names.add(name) + constants_set = getattr(nn_module, "MSG", set()) + for name, ann in class_annotations.items(): + if torch._jit_internal.is_final(ann): + constants_set.add(name) for name in constants_set: + if name in added_names: + if name in nn_module._modules: + hint = "MSG" + elif name in nn_module._buffers: + hint = "MSG" + elif name in nn_module._parameters: + hint = "MSG" + else: + raise AssertionError("MSG") warnings.warn("MSG" + "MSG".format(name, hint)) + continue + if not hasattr(nn_module, name): + warnings.warn("MSG" + "MSG" + "MSG".format(name)) + continue + value = getattr(nn_module, name) + concrete_type_builder.add_constant(name, _get_valid_constant(name, value, type(nn_module).__name__)) + added_names.add(name) + overloads = getattr(nn_module, "MSG", {}) + + overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module))) + for name, overloaded_names in overloads.items(): + concrete_type_builder.add_overload(name, overloaded_names) for name, value in nn_module.__dict__.items(): + if name in ignored_attributes or name.startswith("MSG"): + continue if name in added_names: + continue if inspect.isfunction(value): + try: + scripted_fn = torch.jit.script(value) + concrete_type_builder.add_function_attribute( + name, + torch._C._jit_try_infer_type(scripted_fn), + value) + except Exception as e: + hint = ("MSG" + "MSG" + "MSG").format(e) + concrete_type_builder.add_failed_attribute(name, hint) + pass continue builtin_symbol_name = _find_builtin(value) + if builtin_symbol_name: + concrete_type_builder.add_builtin_function(name, builtin_symbol_name) + continue if isinstance(value, torch.jit.ScriptFunction): + concrete_type_builder.add_function_attribute( + name, + torch._C._jit_try_infer_type(value), + value) + continue attr_type = infer_type(name, value) + if attr_type is not None: + concrete_type_builder.add_attribute(name, attr_type, False, False) + else: + hint = ("MSG" + "MSG" + "MSG").format(torch.typename(type(value))) + concrete_type_builder.add_failed_attribute(name, hint) return concrete_type_builder class ConcreteTypeStore(object): + type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]] + methods_compiled: Set[torch._C.ConcreteModuleType] def __init__(self): + self.type_store = {} + self.methods_compiled = set() def get_or_create_concrete_type(self, nn_module): + concrete_type_builder = infer_concrete_type_builder(nn_module) nn_module_type = type(nn_module) + if nn_module_type not in self.type_store: + self.type_store[nn_module_type] = [] known_types = self.type_store[nn_module_type] + for known_type in known_types: + if known_type.equals(concrete_type_builder): + return known_type concrete_type = concrete_type_builder.build() + self.type_store[nn_module_type].append(concrete_type) + return concrete_type concrete_type_store = ConcreteTypeStore() +def create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs): + method_defs = [m.def_ for m in method_stubs] + method_rcbs = [m.resolution_callback for m in method_stubs] + method_defaults = [get_default_args(m.original_method) for m in method_stubs] property_defs = [p.def_ for p in property_stubs] + property_rcbs = [p.resolution_callback for p in property_stubs] concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults) +def get_module_concrete_type(nn_module, share_types=True): + + assert isinstance(nn_module, Module) + if isinstance(nn_module, torch.jit.ScriptModule) and \ + hasattr(nn_module, "MSG"): + return nn_module._concrete_type if share_types: + concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module) + else: + concrete_type_builder = infer_concrete_type_builder(nn_module, share_types) + concrete_type_builder.set_poisoned() + concrete_type = concrete_type_builder.build() return concrete_type def create_script_module(nn_module, stubs_fn, share_types=True): + + assert not isinstance(nn_module, torch.jit.RecursiveScriptModule) + check_module_initialized(nn_module) + concrete_type = get_module_concrete_type(nn_module, share_types) + return create_script_module_impl(nn_module, concrete_type, stubs_fn) def create_script_module_impl(nn_module, concrete_type, stubs_fn): + + cpp_module = torch._C._create_module_with_type(concrete_type.jit_type) + method_stubs = stubs_fn(nn_module) + property_stubs = get_property_stubs(nn_module) def init_fn(script_module): + for name, (attr_type, is_param) in concrete_type.get_attributes().items(): + orig_value = getattr(nn_module, name) + orig_value = orig_value.value if isinstance(orig_value, torch.jit.Attribute) else orig_value + cpp_module.setattr(name, orig_value) for name, sub_concrete_type in concrete_type.get_modules(): + orig_value = getattr(nn_module, name) + assert isinstance(orig_value, Module), "MSG".format(type(orig_value)) + module_type = sub_concrete_type.jit_type + if isinstance(module_type, torch._C.InterfaceType): + scripted = interface_script(module_type, orig_value) + elif isinstance(orig_value, torch.jit.ScriptModule): + scripted = orig_value + else: + scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn) cpp_module.setattr(name, scripted) + script_module._modules[name] = scripted for name in dir(nn_module): + item = getattr(nn_module, name, None) + if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item): + unbound_function = getattr(type(nn_module), name) + bound_method = unbound_function.__get__(script_module) + setattr(script_module, name, bound_method) script_module._concrete_type = concrete_type + script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) + if concrete_type not in concrete_type_store.methods_compiled: + create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs) + torch._C._run_emit_module_hook(cpp_module) + concrete_type_store.methods_compiled.add(concrete_type) + if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)) and \ + 'MSG' not in cpp_module._method_names(): + script_module.define("MSG".format(len(nn_module))) + if isinstance(nn_module, torch.nn.ModuleDict) and \ + 'MSG' not in cpp_module._method_names(): + if len(nn_module.keys()): + keys = repr(list(nn_module.keys())) + script_module.define("MSG".format(keys)) + else: + script_module.define("MSG") + + for method_stub in method_stubs: + if method_stub.original_method is None: + continue name = method_stub.original_method.__name__ + if name != method_stub.def_.name().name: + continue + script_method = cpp_module._get_method(name) wrapped_script_method = functools.wraps(method_stub.original_method)(script_method) script_module.__dict__[name] = wrapped_script_method + + for property_stub in property_stubs: + property_name = property_stub.def_.name().name + fget = cpp_module._get_method(property_stub.def_.getter_name().name) + setter_name = property_stub.def_.setter_name() + fset = cpp_module._get_method(setter_name.name) if setter_name else None + script_module.__dict__[property_name] = property(property_name, fget, fset) + + for name in dir(nn_module): + item = getattr(nn_module, name, None) + if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER: + add_python_attr_to_scripted_model(script_module, nn_module, name) return script_module def script_model_defines_attr(script_model, attr): + script_attr = getattr(script_model, attr, None) + if script_attr is None: + return False + default_attr = get_function_from_type(torch.jit.RecursiveScriptModule, attr) + if default_attr is None: + return False + return script_attr != default_attr def add_python_attr_to_scripted_model(script_model, orig, attr): + if hasattr(orig, attr) and script_model_defines_attr(script_model, attr): + setattr(script_model, attr, getattr(orig, attr)) def get_overload_annotations(mod): + + overloads = {} for name in dir(type(mod)): + item = getattr(mod, name, None) + if not callable(item): + continue if hasattr(item, "MSG") and item.__module__ is not None: + method_overloads = _jit_internal._get_overloaded_methods(item, mod.__class__) + if method_overloads is None: + continue names = [name + "MSG" + str(i) for i in range(len(method_overloads))] + overloads[item] = list(zip(names, method_overloads)) return overloads def get_overload_name_mapping(overload_info): + + + overload_name_mappings: Dict[str, List[str]] = {} + for orig_fn, overloads in overload_info.items(): + original_name = orig_fn.__name__ + if original_name not in overload_name_mappings: + overload_name_mappings[original_name] = [] for overload_name, _ in overloads: + overload_name_mappings[original_name].append(overload_name) + return overload_name_mappings def _check_no_signature(func): + signature = torch.jit.annotations.get_signature(func, None, _jit_internal.fake_range(), inspect.ismethod(func)) + if signature is None: + qual_name = _jit_internal._qualified_name(func) + raise RuntimeError("MSG".format(qual_name)) def make_stubs_for_overloads(overload_info): + overload_stubs = [] + for orig_fn, overloads in overload_info.items(): + orig_ast = get_jit_def(orig_fn, orig_fn.__name__, self_name="MSG") + for overload_name, overload_fn in overloads: + _check_no_signature(overload_fn) + over_ast = get_jit_def(overload_fn, overload_fn.__name__, self_name="MSG") + new_ast = torch._C._replace_overloaded_method_decl(over_ast.decl(), orig_ast, overload_name) + _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn) + overload_stubs.append(ScriptMethodStub(_rcb, new_ast, overload_fn)) + return overload_stubs def check_module_initialized(mod): + assert isinstance(mod, torch.nn.Module) + if not hasattr(mod, 'MSG'): + raise RuntimeError("MSG" + .format(torch.typename(type(mod)))) def infer_methods_to_compile(nn_module): + + check_module_initialized(nn_module) methods: List[str] = [] + if hasattr(nn_module, 'MSG') and not _jit_internal.is_ignored_fn(nn_module.forward): + forward_func = getattr(nn_module.forward, "MSG", None) + module_forward = get_function_from_type(torch.nn.Module, "MSG") + if forward_func != module_forward: + methods = ['MSG'] exported = [] + for name in dir(nn_module): + item = getattr(nn_module, name, None) + if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT: + exported.append(name) methods = methods + exported overload_name_mappings = dict(getattr(nn_module, "MSG", {})) + overload_info = get_overload_annotations(nn_module) + overload_name_mappings.update(get_overload_name_mapping(overload_info)) + overload_stubs = make_stubs_for_overloads(overload_info) nn_module.__overloads__ = overload_name_mappings + def ignore_overloaded(method_name): + return method_name not in overload_name_mappings filtered_methods = filter(ignore_overloaded, methods) + + uniquer: Set[str] = set() + uniqued_methods = [] + for name in filtered_methods: + if name in uniquer: + continue + uniqued_methods.append(name) + uniquer.add(name) stubs = [] + for method in uniqued_methods: + stubs.append(make_stub_from_method(nn_module, method)) + return overload_stubs + stubs +def get_property_stubs(nn_module): + + module_ty = type(nn_module) + properties_asts = get_class_properties(module_ty, self_name="MSG") + rcbs = {} for name in dir(module_ty): + item = getattr(module_ty, name, None) + if isinstance(item, property): + if not item.fget: + raise RuntimeError(f'MSG') rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget) stubs = [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts] + return stubs +def interface_script(mod_interface, nn_module): + + if isinstance(nn_module, torch.jit.ScriptModule): + return nn_module check_module_initialized(nn_module) def infer_interface_methods_to_compile(nn_module): + stubs = [] + for method in mod_interface.getMethodNames(): + stubs.append(make_stub_from_method(nn_module, method)) + return stubs return create_script_module(nn_module, infer_interface_methods_to_compile) def try_compile_fn(fn, loc): + if _jit_internal.is_ignored_fn(fn): + return None if isinstance(fn, torch.nn.Module): + return None if not inspect.isfunction(fn) and not inspect.ismethod(fn): + raise RuntimeError("MSG" + "MSG" + "MSG".format(fn, fn)) + + + rcb = _jit_internal.createResolutionCallbackFromClosure(fn) + return torch.jit.script(fn, _rcb=rcb) def wrap_cpp_module(cpp_module): + + def init_fn(script_module): + for name, cpp_module in torch._C.ModuleDict(script_module._c).items(): + setattr(script_module, name, wrap_cpp_module(cpp_module)) + script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(script_module._c._type()) + return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) def compile_unbound_method(concrete_type, fn): + if _jit_internal.is_ignored_fn(fn): + return None + stub = make_stub(fn, fn.__name__) + with torch._jit_internal._disable_emit_hooks(): + create_methods_and_properties_from_stubs(concrete_type, (stub,), ()) + return stub def lazy_bind(concrete_type, unbound_method): + + def lazy_binding_method(cpp_module, *args): + def init_fn(script_module): + orig_class = concrete_type.py_class for name in dir(orig_class): + item = getattr(orig_class, name, None) + if _jit_internal.is_ignored_fn(item): + setattr(script_module, name, item) for name, value in concrete_type.get_constants().items(): + setattr(script_module, name, value) script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) + method = bind_method(unbound_method, script_module, torch.jit.RecursiveScriptModule) + return method(*args) + lazy_binding_method.original_fn = unbound_method + lazy_binding_method.__name__ = unbound_method.__name__ + torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method) return lazy_binding_method import functools +import collections +import inspect +import copy +import pickle +import warnings +from typing import Any, Dict +import torch +import torch._jit_internal as _jit_internal +from torch.utils import set_module +from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module, infer_methods_to_compile +from torch.nn import Module +from torch.jit._state import _enabled +from torch.jit._builtins import _register_builtin +from torch._six import with_metaclass, get_function_from_type +from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def +from torch._jit_internal import _qualified_name +from torch.jit._fuser import _graph_for +from torch.jit._state import ( + _try_get_jit_cached_function, + _try_get_jit_cached_overloads, + _set_jit_function_cache, + _set_jit_overload_cache, +) torch._C.ScriptMethod.graph_for = _graph_for +torch._C.ScriptFunction.graph_for = _graph_for +ScriptFunction = torch._C.ScriptFunction +ScriptFunction.__doc__ = +set_module(ScriptFunction, "MSG") +if _enabled: + Attribute = collections.namedtuple("MSG", ["MSG", "MSG"]) +else: def Attribute(value, type): + return value +def _is_new_style_class(cls): + if hasattr(cls, "MSG"): + return "MSG" in dir(cls) or hasattr(cls, "MSG") +def _compile_and_register_class(obj, rcb, qualified_name): + ast = get_jit_class_def(obj, obj.__name__) + defaults = torch.jit.frontend.get_default_args_for_class(obj) + torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb) + torch.jit._state._add_script_class(obj, qualified_name) class OrderedDictWrapper(object): + def __init__(self, _c): + self._c = _c def keys(self): + return [k for k, v in self.items()] def values(self): + return [v for k, v in self.items()] def __len__(self): + return len(self.values()) def __delitem__(self, k): + raise RuntimeError("MSG") def items(self): + return self._c.items() def __setitem__(self, k, v): + if k not in self: + raise RuntimeError( + "MSG" + "MSG".format(k) + ) + self._c.setattr(k, v) def __contains__(self, k): + return self._c.contains(k) def __getitem__(self, k): + if k not in self: + raise KeyError(k) + return self._c.getattr(k) +class OrderedModuleDict(OrderedDictWrapper): + def __init__(self, module, python_dict): + super(OrderedModuleDict, self).__init__(torch._C.ModuleDict(module)) + self._python_modules = python_dict def items(self): + r = self._python_modules.items() + return r def __contains__(self, k): + return k in self._python_modules def __setitem__(self, k, v): + if isinstance(v, ScriptModule): + self._c.setattr(k, v) + self._python_modules[k] = v + else: + raise RuntimeError( + "MSG" + "MSG".format(k, v) + ) def __getitem__(self, k): + return self._python_modules[k] +class ScriptMeta(type): + def __init__(cls, name, bases, attrs): + cls._methods: Dict[str, Any] = {} + cls._constants_set = set(getattr(cls, "MSG", ())) + for base in reversed(bases): + for k, v in getattr(base, "MSG", {}).items(): + cls._methods[k] = v + base_constants = getattr(base, "MSG", set()) + cls._constants_set = cls._constants_set.union(base_constants) for k, v in sorted(attrs.items()): + if isinstance(v, ScriptMethodStub): + delattr(cls, k) + cls._methods[v.original_method.__name__] = v if getattr(cls, "MSG", False): + return super(ScriptMeta, cls).__init__(name, bases, attrs) original_init = getattr(cls, "MSG", lambda self: None) @functools.wraps(original_init) + def init_then_script(self, *args, **kwargs): + num_methods = len(cls._methods) + original_init(self, *args, **kwargs) + added_methods_in_init = len(cls._methods) > num_methods if type(self) == cls: def make_stubs(module): + cls = type(module) + if hasattr(cls, "MSG"): + return [v for k, v in sorted(cls._methods.items())] + else: + return infer_methods_to_compile(module) self.__dict__[ + "MSG" + ] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init) concrete_type = self._actual_script_module._concrete_type + for name in concrete_type.get_attributes(): + delattr(self, name) + for name, _ in concrete_type.get_modules(): + delattr(self, name) + for name in ("MSG", "MSG", "MSG"): + delattr(self, name) cls.__init__ = init_then_script + return super(ScriptMeta, cls).__init__(name, bases, attrs) +class _CachedForward(object): + def __get__(self, obj, cls): + return self.__getattr__("MSG") +class ScriptWarning(Warning): + pass +def script_method(fn): + if not _enabled: + return fn + + + + + + + + + + + + + _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2) + ast = get_jit_def(fn, fn.__name__, self_name="MSG") + return ScriptMethodStub(_rcb, ast, fn) +class ConstMap: + def __init__(self, const_mapping): + self.const_mapping = const_mapping def __getattr__(self, attr): + return self.const_mapping[attr] +if _enabled: + + + + + + + class ScriptModule(with_metaclass(ScriptMeta, Module)): + __jit_unused_properties__ = ['MSG', 'MSG', 'MSG', 'MSG', 'MSG'] def __init__(self): + super(ScriptModule, self).__init__() forward = _CachedForward() def __getattr__(self, attr): + if "MSG" not in self.__dict__: + return super(ScriptModule, self).__getattr__(attr) + return getattr(self._actual_script_module, attr) def __setattr__(self, attr, value): + if "MSG" not in self.__dict__: + if isinstance(value, Attribute): + if "MSG" not in self.__class__.__dict__: + self.__class__.__annotations__ = {} + self.__annotations__[attr] = value.type + value = value.value + return super(ScriptModule, self).__setattr__(attr, value) setattr(self._actual_script_module, attr, value) def define(self, src): + if "MSG" in self.__dict__: + return self._actual_script_module.define(src) rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1) + ast = torch._C._parse_source_def(src) + self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None) def _replicate_for_data_parallel(self): + return self._actual_script_module._replicate_for_data_parallel() class RecursiveScriptModule(ScriptModule): + r + _disable_script_meta = True def __init__(self, cpp_module): + self.__dict__["MSG"] = True + self._c = cpp_module + super(RecursiveScriptModule, self).__init__() + delattr(self, "MSG") @staticmethod + def _construct(cpp_module, init_fn): + script_module = RecursiveScriptModule(cpp_module) + init_fn(script_module) RecursiveScriptModule._finalize_scriptmodule(script_module) + return script_module @staticmethod + def _finalize_scriptmodule(script_module): + script_module._parameters = OrderedDictWrapper( + torch._C.ParameterDict(script_module._c) + ) + script_module._buffers = OrderedDictWrapper( + torch._C.BufferDict(script_module._c) + ) + script_module._modules = OrderedModuleDict( + script_module._c, script_module._modules + ) + script_module._initializing = False def _reconstruct(self, cpp_module): + self.__init__(cpp_module) self._concrete_type = torch._C.ConcreteModuleType.from_jit_type( + self._c._type() + ) modules = {} + for name, cpp_module in torch._C.ModuleDict(self._c).items(): + modules[name] = wrap_cpp_module(cpp_module) + self._modules = OrderedModuleDict(self._c, modules) self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c)) + self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c)) self.__dict__ = { + k: v + for k, v in self.__dict__.items() + if not isinstance(v, torch._C.ScriptMethod) + } + self.__dict__["MSG"] = False @property + def graph(self): + r + return self.forward.graph @property + def inlined_graph(self): + r + return self.forward.inlined_graph @property + def code(self): + r + return self.forward.code @property + def code_with_constants(self): + r + r = self.forward.code_with_constants + return (r[0], ConstMap(r[1])) def save(self, *args, **kwargs): + r + return self._c.save(*args, **kwargs) def _save_for_lite_interpreter(self, *args, **kwargs): + r + return self._c._save_for_mobile(*args, **kwargs) def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): + return self._c._save_to_buffer_for_mobile(*args, **kwargs) def save_to_buffer(self, *args, **kwargs): + return self._c.save_to_buffer(*args, **kwargs) def get_debug_state(self, *args, **kwargs): + return self._c.get_debug_state() def extra_repr(self): + return "MSG".format(self.original_name) def graph_for(self, *args, **kwargs): + return self.forward.graph_for(*args, **kwargs) @property + def original_name(self): + if type(self) == str(self._c._type().name()): + return "MSG" + return str(self._c._type().name()) def define(self, src): + rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1) + self._c._define(self._concrete_type, src, rcb) def __getattr__(self, attr): + if "MSG" not in self.__dict__: + raise RuntimeError( + "MSG" + ) if self._initializing: + return super(RecursiveScriptModule, self).__getattr__(attr) if attr in self._modules: + return self._modules[attr] + elif self._c.hasattr(attr): + return self._c.getattr(attr) + elif self._c._has_method(attr): + script_method = self._c._get_method(attr) + self.__dict__[attr] = script_method + return script_method return super(RecursiveScriptModule, self).__getattr__(attr) def __setattr__(self, attr, value): + if self._initializing: + return super(RecursiveScriptModule, self).__setattr__(attr, value) if attr in self._modules: + self._modules[attr] = value + elif self._c.hasattr(attr): + self._c.setattr(attr, value) + elif ( + hasattr(self, "MSG") + and attr in self._concrete_type.get_constants().keys() + ): + raise AttributeError( + "MSG".format( + attr, value + ) + ) + else: + return super(RecursiveScriptModule, self).__setattr__(attr, value) def __getstate__(self): + raise pickle.PickleError( + "MSG" + + "MSG" + + "MSG" + ) def __copy__(self): + return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c)) def __deepcopy__(self, memo): + return torch.jit._recursive.wrap_cpp_module(copy.deepcopy(self._c, memo)) def forward_magic_method(self, method_name, *args, **kwargs): + self_method = getattr(self, method_name) + if getattr(self_method, "MSG", None) == getattr( + RecursiveScriptModule, method_name + ): + raise NotImplementedError() + return self_method(*args, **kwargs) def __iter__(self): + return self.forward_magic_method("MSG") def __getitem__(self, idx): + return self.forward_magic_method("MSG", idx) def __len__(self): + return self.forward_magic_method("MSG") def __contains__(self, key): + return self.forward_magic_method("MSG", key) def __dir__(self): + self_method = self.__dir__ + if self_method.__func__ == get_function_from_type( + RecursiveScriptModule, "MSG" + ): + return super(RecursiveScriptModule, self).__dir__() + return self_method() def __bool__(self): + self_method = self.__bool__ + if self_method.__func__ == get_function_from_type( + RecursiveScriptModule, "MSG" + ): + return True + return self_method() def _replicate_for_data_parallel(self): + def init_fn(script_module): + return return RecursiveScriptModule._construct( + self._c._replicate_for_data_parallel(), init_fn + ) + + + + + for name, item in RecursiveScriptModule.__dict__.items(): + if not callable(item) and not isinstance(item, property): + continue + if name.startswith("MSG") or hasattr(ScriptModule, name): + continue + setattr(ScriptModule, name, item) def _get_methods(cls): + import inspect return inspect.getmembers( + cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x) + ) _compiled_methods_allowlist = { + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + "MSG", + } def _make_fail(name): + def fail(self, *args, **kwargs): + raise RuntimeError(name + "MSG") return fail for name, method in _get_methods(torch.nn.Module): + if name.startswith("MSG"): + continue + if ( + name not in RecursiveScriptModule.__dict__ + and name not in _compiled_methods_allowlist + ): + setattr(RecursiveScriptModule, method.__name__, _make_fail(name)) +else: + + class ScriptModule(torch.nn.Module): + def __init__(self, arg=None): + super().__init__() class RecursiveScriptModule(ScriptModule): + def __init__(self, arg=None): + super().__init__() +def script(obj, optimize=None, _frames_up=0, _rcb=None): + r + if not _enabled: + return obj if optimize is not None: + warnings.warn( + "MSG" + ) + if isinstance(obj, ScriptModule): + return obj if isinstance(obj, torch.nn.Module): + return torch.jit._recursive.create_script_module( + obj, torch.jit._recursive.infer_methods_to_compile + ) qualified_name = _qualified_name(obj) + if inspect.isclass(obj): + if issubclass(obj, torch.nn.Module): + raise RuntimeError( + "MSG" + "MSG" + "MSG".format(obj) + ) if not _is_new_style_class(obj): + raise RuntimeError( + "MSG" + "MSG" + ) + if len(obj.mro()) > 2: + raise RuntimeError( + "MSG" + "MSG" + ) + if _rcb is None: + _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1) + _compile_and_register_class(obj, _rcb, qualified_name) + return obj + else: + if hasattr(obj, "MSG"): + obj = obj.__original_fn + _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) _check_directly_compile_overloaded(obj) + maybe_already_compiled_fn = _try_get_jit_cached_function(obj) + if maybe_already_compiled_fn: + return maybe_already_compiled_fn + ast = get_jit_def(obj, obj.__name__) + if _rcb is None: + _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) + fn = torch._C._jit_script_compile( + qualified_name, ast, _rcb, get_default_args(obj) + ) + fn.__doc__ = obj.__doc__ + _set_jit_function_cache(obj, fn) + return fn +def _check_overload_defaults(impl_defaults, overload_defaults, loc): + for name, overload_value in overload_defaults.items(): + if name not in impl_defaults or impl_defaults[name] != overload_value: + raise torch.jit.frontend.FrontendError( + loc, + "MSG" + "MSG" + "MSG".format(name=name), + ) +def _compile_function_with_overload(overload_fn, qual_name, impl_fn): + overload_decl = get_jit_def(overload_fn, overload_fn.__name__).decl() + overload_signature = torch.jit.annotations.get_signature( + overload_fn, None, None, inspect.ismethod(overload_fn) + ) + impl_ast = get_jit_def(impl_fn, impl_fn.__name__) + overload_defaults = get_default_args(overload_fn) + implementation_defaults = get_default_args(impl_fn) + _rcb = _jit_internal.createResolutionCallbackFromClosure(impl_fn) + _check_overload_defaults( + implementation_defaults, overload_defaults, overload_decl.range() + ) + fn = torch._C._jit_script_compile_overload( + qual_name, + overload_decl, + impl_ast, + _rcb, + implementation_defaults, + overload_signature, + ) + return fn +def _get_overloads(obj): + + existing_compiled_fns = _try_get_jit_cached_overloads(obj) + qual_name = _qualified_name(obj) + uncompiled_overloads = _jit_internal._get_fn_overloads(qual_name) + if uncompiled_overloads is None: + return existing_compiled_fns compiled_fns = [] + for overload_fn in uncompiled_overloads: + compiled_fns.append( + _compile_function_with_overload(overload_fn, qual_name, obj) + ) if existing_compiled_fns: + compiled_fns = existing_compiled_fns + compiled_fns + _set_jit_overload_cache(obj, compiled_fns) + _jit_internal._clear_fn_overloads(qual_name) + return compiled_fns +def _check_directly_compile_overloaded(obj): + qual_name = _qualified_name(obj) + if _jit_internal._get_fn_overloads(qual_name) or _try_get_jit_cached_overloads(obj): + raise RuntimeError( + "MSG" + "MSG" + "MSG".format(qual_name) + ) +def interface(obj): + if not inspect.isclass(obj): + raise RuntimeError("MSG") + if not _is_new_style_class(obj): + raise RuntimeError("MSG") + + + + is_module_interface = issubclass(obj, torch.nn.Module) and len(obj.mro()) == 3 if not is_module_interface and len(obj.mro()) > 2: + raise RuntimeError( + "MSG" + "MSG" + ) qualified_name = _qualified_name(obj) + rcb = _jit_internal.createResolutionCallbackFromFrame(1) + + + + ast = get_jit_class_def(obj, obj.__name__) + torch._C._jit_script_interface_compile( + qualified_name, ast, rcb, is_module_interface + ) + obj.__torch_script_interface__ = True + return obj +def _recursive_compile_class(obj, loc): + _qual_name = _qualified_name(obj) + + + error_stack = torch._C.CallStack(_qual_name, loc) + rcb = _jit_internal.createResolutionCallbackForClassMethods(obj) + _compile_and_register_class(obj, rcb, _qual_name) +class CompilationUnit(object): + def __init__(self, lang=None, _frames_up=0): + self._c = torch._C.CompilationUnit() + if lang is not None: + self.define(lang, _frames_up=_frames_up + 1) def define(self, lang, rcb=None, _frames_up=0): + if not rcb: + rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1) + self._c.define(lang, rcb) def __getattr__(self, attr): + r = self._c.find_function(attr) + if r is None: + raise AttributeError("MSG".format(attr)) + return r +def _unwrap_optional(x): + assert x is not None, "MSG" + return x +_register_builtin(_unwrap_optional, "MSG") +_register_builtin(_jit_internal.is_scripting, "MSG") import os +import pathlib import torch +from torch._six import string_classes +from torch.jit._recursive import wrap_cpp_module +from torch.serialization import validate_cuda_device +def save(m, f, _extra_files=None): + r + if _extra_files is None: + _extra_files = {} + if isinstance(f, str) or isinstance(f, pathlib.Path): + m.save(f, _extra_files=_extra_files) + else: + ret = m.save_to_buffer(_extra_files=_extra_files) + f.write(ret) +def load(f, map_location=None, _extra_files=None): + r + if isinstance(f, string_classes): + if not os.path.exists(f): + raise ValueError("MSG".format(f)) + if os.path.isdir(f): + raise ValueError("MSG".format(f)) map_location = validate_map_location(map_location) + if _extra_files is None: + _extra_files = {} cu = torch._C.CompilationUnit() + if isinstance(f, str) or isinstance(f, pathlib.Path): + cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files) + else: + cpp_module = torch._C.import_ir_module_from_buffer( + cu, f.read(), map_location, _extra_files + ) + return wrap_cpp_module(cpp_module) +def validate_map_location(map_location=None): + if isinstance(map_location, str): + map_location = torch.device(map_location) + elif not (map_location is None or isinstance(map_location, torch.device)): + raise ValueError( + "MSG" + "MSG" + str(type(map_location)) + ) if str(map_location).startswith("MSG"): + validate_cuda_device(map_location) return map_location import torch +import os +import weakref class EnabledProxy: + def __init__(self): + self.enabled = self.parse_env( + "MSG", True, "MSG", "MSG" + ) def parse_env(self, name, default, true_message, false_message): + value = os.environ.get(name) + if value is None: + return default + if value.lower() in {"MSG", "MSG", "MSG"}: + return True + elif value.lower() in {"MSG", "MSG", "MSG"}: + return False + if value == "MSG": + print(true_message) + return True + elif value == "MSG": + print(false_message) + return False + raise ValueError("MSG".format(name)) def __bool__(self): + return self.enabled +_enabled = EnabledProxy() +def disable(): + _enabled.enabled = False +def enable(): + _enabled.enabled = True _python_cu = torch._C.CompilationUnit() _script_classes = {} def _add_script_class(cls, name): + global _script_classes + _script_classes[name] = cls +def _get_script_class(name): + global _script_classes + if name not in _script_classes: + return None + return _script_classes[name] +_jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() +_jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() def _try_get_jit_cached_overloads(key): + qual_names = _jit_function_overload_caching.get(key, None) + if qual_names: + return [_python_cu.find_function(qual_name) for qual_name in qual_names] + else: + return None def _set_jit_overload_cache(key, compiled_fns): + _jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns] def _try_get_jit_cached_function(key): + if getattr(key, "MSG", False) is True: + return None + qual_name = _jit_caching_layer.get(key, None) + if qual_name: + return _python_cu.find_function(qual_name) + else: + return None def _set_jit_function_cache(key, value): + + assert isinstance(value, torch.jit.ScriptFunction) + _jit_caching_layer[key] = value.qualified_name import torch import copy +import os +import contextlib +import functools +import warnings +import inspect +import re +from typing import Any, Dict, List, Optional, Set from torch.jit._state import _python_cu, _enabled +from torch.jit._script import ScriptModule, _CachedForward, script +from torch._jit_internal import _qualified_name +from torch.autograd import function +from torch.nn import Module _flatten = torch._C._jit_flatten +_unflatten = torch._C._jit_unflatten +def _create_interpreter_name_lookup_fn(frames_up=1): + def _get_interpreter_name_for_var(var): + frame = inspect.currentframe() + if not frame: + raise RuntimeError("MSG") i = 0 + while i < frames_up + 1: + frame = frame.f_back + if not frame: + raise RuntimeError("MSG") + i += 1 f_locals = frame.f_locals + f_globals = frame.f_globals for k, v in f_locals.items(): + if isinstance(v, torch.Tensor) and var is v: + return k if k != "MSG" else "MSG" + return "MSG" return _get_interpreter_name_for_var +def _unique_state_dict(module, keep_vars=False): + + + + state_dict = module.state_dict(keep_vars=True) + filtered_dict = type(state_dict)() + seen_ids: Set[int] = set() + for k, v in state_dict.items(): + if id(v) in seen_ids: + continue + seen_ids.add(id(v)) + if keep_vars: + filtered_dict[k] = v + else: + filtered_dict[k] = v.detach() + return filtered_dict +class ONNXTracedModule(torch.nn.Module): + def __init__( + self, + inner, + strict=True, + force_outplace=False, + return_inputs=False, + return_inputs_states=False, + ): + super(ONNXTracedModule, self).__init__() + self.inner = inner + self.strict = strict + self._force_outplace = force_outplace + self._return_inputs = return_inputs + self._return_inputs_states = return_inputs_states def forward(self, *args: torch.Tensor): + in_vars, in_desc = _flatten(args) + module_state = list(_unique_state_dict(self, keep_vars=True).values()) ret_inputs = [] + inputs_states = [] + outs = [] def wrapper(*args): + in_args: List[torch.Tensor] = [] + for i in range(len(in_vars)): + if not isinstance(args[i], torch.Tensor): + raise RuntimeError('MSG') + in_args.append(args[i]) trace_inputs = _unflatten(in_args, in_desc) ret_inputs.append( + tuple(x.clone(memory_format=torch.preserve_format) for x in args) + ) + if self._return_inputs_states: + inputs_states.append(_unflatten(in_args, in_desc)) + outs.append(self.inner(*trace_inputs)) + if self._return_inputs_states: + inputs_states[0] = (inputs_states[0], trace_inputs) + out_vars, _ = _flatten(outs) + if len(out_vars) == 1: + return out_vars[0] + else: + return tuple(out_vars) graph, out = torch._C._create_graph_by_tracing( + wrapper, + in_vars + module_state, + _create_interpreter_name_lookup_fn(), + self.strict, + self._force_outplace, + ) if self._return_inputs: + return graph, outs[0], ret_inputs[0] + if self._return_inputs_states: + return graph, outs[0], inputs_states[0] + else: + return graph, outs[0] +def _clone_inputs(args): + def clone_input(a): + if a is None: + return None + elif isinstance(a, torch.Tensor): + v = ( + a.detach() + .clone(memory_format=torch.preserve_format) + .requires_grad_(a.requires_grad) + ) + if a.grad is not None: + v.grad = clone_input(v.grad) + return v + else: + return a.clone(memory_format=torch.preserve_format) return function._nested_map( + lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="MSG" + )(args) _JIT_TIME = os.environ.get("MSG", False) +_JIT_DISABLE = os.environ.get("MSG", False) +_JIT_STATS = os.environ.get("MSG", False) +@contextlib.contextmanager +def _time(trace_name, name, time=True): + if (not _JIT_TIME and not time) or not torch.cuda.is_available(): + yield + return + stream = torch.cuda.current_stream() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + stream.record_event(start) + try: + yield + finally: + stream.record_event(end) + end.synchronize() + print("MSG".format(trace_name, name, start.elapsed_time(end))) +def verify(model, args, loss_fn=torch.sum, devices=None): + + + + + + if not isinstance(model, torch._C.CompiledFunction): + raise TypeError( + "MSG" + ) + is_module = isinstance(model, Module) if not isinstance(args, tuple): + args = (args,) saved_args = _clone_inputs(args) + if is_module: + saved_state = copy.deepcopy(model.state_dict()) def run_fwd_bwd(args, force_trace=False, assert_compiled=False): + params = list(model.parameters()) if is_module else [] + in_vars, _ = _flatten((args, params)) + compiled_fn = model + if force_trace: + compiled_fn.clear_cache() + if assert_compiled: + hits = compiled_fn.hits + out = model(*args) + if assert_compiled and compiled_fn.hits == hits: + raise RuntimeError("MSG") + if not isinstance(out, tuple): + out = (out,) + if loss_fn == torch.sum and len(out) != 1: + raise ValueError( + ( + "MSG" + "MSG" + ).format(len(out)) + ) + out_vars, _ = _flatten(out) + saved_outs = [ + v.detach().clone(memory_format=torch.preserve_format) for v in out_vars + ] + loss = loss_fn(*out) + grads = torch.autograd.grad([loss], in_vars) + saved_grads = [ + v.detach().clone(memory_format=torch.preserve_format) for v in grads + ] + return (saved_outs, saved_grads) with torch.random.fork_rng(devices, _caller="MSG"): + uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True) + assert model.has_trace_for(*args) if is_module: + model.load_state_dict(saved_state) + compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True) _verify_equal(uncompiled_outs, compiled_outs) + _verify_equal(uncompiled_grads, compiled_grads) +def _verify_equal(xs, ys): + for x, y in zip(xs, ys): + if x.sub(y).abs().max() > 1e-6: + raise RuntimeError("MSG") +def indent(s): + return "MSG".join(["MSG" + line for line in s.splitlines()]) +class TracingCheckError(Exception): + def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None): + self.message = "MSG" + if extra_msg is not None: + self.message += extra_msg + "MSG" + if graph_diff_error is not None: + self.message += "MSG" + self.message += indent(graph_diff_error) + "MSG" + if tensor_compare_error is not None: + self.message += ( + "MSG" + "MSG" + "MSG" + ) + self.message += indent(tensor_compare_error) + "MSG" + super(TracingCheckError, self).__init__(self.message) @torch.no_grad() +def _check_trace( + check_inputs, + func, + traced_func, + check_tolerance, + strict, + force_outplace, + is_trace_module, + _module_class, +): + + for inputs in check_inputs: if isinstance(inputs, torch.Tensor): + inputs = (inputs,) if is_trace_module: + copied_dict = {} + for name, data in inputs.items(): + copied_dict[name] = _clone_inputs(data) + check_mod = torch.jit.trace_module( + func.__self__ if hasattr(func, "MSG") else func, + copied_dict, + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + _compilation_unit=torch._C.CompilationUnit(), + ) + check_mod_func = check_mod._c._get_method(traced_func.name) + inputs = inputs[traced_func.name] + if isinstance(inputs, (torch.Tensor, dict)): + inputs = (inputs,) + else: + check_mod = torch.jit.trace( + func, + _clone_inputs(inputs), + check_trace=False, + strict=strict, + _force_outplace=force_outplace, + _module_class=_module_class, + ) + check_mod_func = check_mod def graph_diagnostic_info(): + mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph) + torch._C._jit_pass_inline(mod_canonicalized) + torch._C._jit_pass_erase_shape_information(mod_canonicalized) + mod_str = str(mod_canonicalized) + mod_str = re.sub(r"MSG", "MSG", mod_str) + check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph) + torch._C._jit_pass_inline(check_canonicalized) + torch._C._jit_pass_erase_shape_information(check_canonicalized) + check_str = str(check_canonicalized) + check_str = re.sub(r"MSG", "MSG", check_str) graph_diff_errors = None + if mod_str != check_str: + import difflib graph_diff = difflib.ndiff( + mod_str.splitlines(True), check_str.splitlines(True) + ) + graph_diff_errors = "MSG" + indent("MSG".join(graph_diff)) + "MSG" for n_mod, n_check in zip( + mod_canonicalized.nodes(), check_canonicalized.nodes() + ): + if str(n_mod) != str(n_check): + graph_diff_errors += "MSG" + node_diff = difflib.ndiff( + str(n_mod).splitlines(True), str(n_check).splitlines(True) + ) + source_printout = ( + "MSG" + indent("MSG".join(node_diff)) + "MSG" + ) + mod_stack = n_mod.sourceRange() + if mod_stack: + source_printout += ( + "MSG" + indent(mod_stack) + "MSG" + ) + check_stack = n_check.sourceRange() + if check_stack: + source_printout += ( + "MSG" + indent(check_stack) + "MSG" + ) + graph_diff_errors += source_printout break tensor_compare_errors = None + for n_mod, n_check in zip( + mod_canonicalized.nodes(), check_canonicalized.nodes() + ): + if n_mod.kind() != n_check.kind(): + break if n_mod.kind() == "MSG" and not ( + n_mod.mustBeNone() or n_check.mustBeNone() + ): + if not n_mod.hasAttribute("MSG"): + continue + if n_mod.kindOf("MSG") != "MSG" or n_check.kindOf("MSG") != "MSG": + continue mod_tensor_val = n_mod.t("MSG") + check_tensor_val = n_check.t("MSG") try: + torch.testing.assert_allclose(mod_tensor_val, check_tensor_val) + except (RuntimeError, AssertionError) as e: + if tensor_compare_errors is None: + tensor_compare_errors = "MSG" + tensor_compare_errors += "MSG" + indent(str(n_mod)) + "MSG" + compare_stack = n_mod.sourceRange() + if compare_stack: + tensor_compare_errors += ( + "MSG" + indent(compare_stack) + "MSG" + ) + tensor_compare_errors += "MSG" + indent( + str(e) + ) break return graph_diff_errors, tensor_compare_errors def wrap_retval(x): + return x if isinstance(x, tuple) else (x,) def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): + try: + outs = wrap_retval(mod(*_clone_inputs(inputs))) + outs = [out for out in outs if isinstance(out, torch.Tensor)] + return outs + except Exception as e: + graph_diff_errors, tensor_compare_errors = graph_diagnostic_info() + msg = f"MSG" + raise TracingCheckError( + graph_diff_errors, + tensor_compare_errors, + extra_msg=msg, + ) from e has_warned = [False] def maybe_warn_nondeterministic(): + if has_warned[0]: + return + has_warned[0] = True + nondeterm_ops = [ + op for op in traced_func.graph.nodes() if op.isNondeterministic() + ] + if len(nondeterm_ops) > 0: + nondeterministic_ops_warning = "MSG" + nondeterministic_ops_warning += ( + "MSG" + ) + nondeterministic_ops_warning += "MSG".join( + [indent(str(op)) for op in nondeterm_ops][:20] + ) + nondeterministic_ops_warning += ( + "MSG" + "MSG" + ) + warnings.warn( + nondeterministic_ops_warning, category=TracerWarning, stacklevel=5 + ) def compare_outputs(original, reference, match_what): + all_ok = True + for i, (orig, ref) in enumerate(zip(original, reference)): + try: + if orig.is_quantized: + orig = orig.dequantize() + if ref.is_quantized: + ref = ref.dequantize() + torch.testing.assert_allclose( + orig.double(), + ref.double(), + rtol=check_tolerance, + atol=torch.testing._get_default_tolerance(orig, ref)[1], + ) + except AssertionError as e: + maybe_warn_nondeterministic() + warnings.warn( + "MSG" + + str(i + 1) + + "MSG" + "MSG" + + match_what + + "MSG" + + str(e), + category=TracerWarning, + stacklevel=4, + ) + all_ok = False return all_ok traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "MSG") + fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "MSG") + if compare_outputs(traced_outs, fn_outs, "MSG"): + check_outs = run_mod_and_filter_tensor_outputs( + check_mod_func, inputs, "MSG" + ) + compare_outputs(traced_outs, check_outs, "MSG") diag_info = graph_diagnostic_info() + if any(info is not None for info in diag_info): + raise TracingCheckError(*diag_info) +class TracerWarning(Warning): + @staticmethod + def ignore_lib_warnings(): + warnings.filterwarnings( + "MSG", category=TracerWarning, module="MSG" + ) +TracerWarning.ignore_lib_warnings() +torch._C._tracer_warn_use_python() +def make_tuple(example_inputs): + if isinstance(example_inputs, (torch.Tensor, dict)): + return (example_inputs,) + + if not isinstance(example_inputs, tuple): + return tuple(example_inputs) + return example_inputs +def make_module(mod, _module_class, _compilation_unit): + if isinstance(mod, ScriptModule): + return mod + elif torch._jit_internal.module_has_exports(mod): infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods + return torch.jit._recursive.create_script_module( + mod, + infer_methods_stubs_fn, + share_types=False + ) + else: + if _module_class is None: + _module_class = TopLevelTracedModule + return _module_class(mod, _compilation_unit=_compilation_unit) +def wrap_check_inputs(check_inputs): + if check_inputs is None: + return None return [{"MSG": c} for c in check_inputs] +def trace( + func, + example_inputs, + optimize=None, + check_trace=True, + check_inputs=None, + check_tolerance=1e-5, + strict=True, + _force_outplace=False, + _module_class=None, + _compilation_unit=_python_cu, +): + + if not _enabled: + return func + if optimize is not None: + warnings.warn( + "MSG" + ) if isinstance(func, torch.jit.ScriptModule): + warnings.warn( + "MSG" + ) + return func if isinstance(func, torch.nn.Module): + return trace_module( + func, + {"MSG": example_inputs}, + None, + check_trace, + wrap_check_inputs(check_inputs), + check_tolerance, + strict, + _force_outplace, + _module_class, + ) if ( + hasattr(func, "MSG") + and isinstance(func.__self__, torch.nn.Module) + and func.__name__ == "MSG" + ): + return trace_module( + func.__self__, + {"MSG": example_inputs}, + None, + check_trace, + wrap_check_inputs(check_inputs), + check_tolerance, + strict, + _force_outplace, + _module_class, + ) + if isinstance(example_inputs, (torch.Tensor, dict)): + example_inputs = (example_inputs,) + + elif not isinstance(example_inputs, tuple): + example_inputs = tuple(example_inputs) var_lookup_fn = _create_interpreter_name_lookup_fn(0) if hasattr(func, "MSG") and isinstance(func.__self__, torch.nn.Module): + raise AttributeError( + "MSG" + "MSG" + ) name = _qualified_name(func) + traced = torch._C._create_function_from_trace( + name, func, example_inputs, var_lookup_fn, strict, _force_outplace + ) + if check_trace: + if check_inputs is not None: + _check_trace( + check_inputs, + func, + traced, + check_tolerance, + strict, + _force_outplace, + False, + _module_class, + ) + else: + _check_trace( + [example_inputs], + func, + traced, + check_tolerance, + strict, + _force_outplace, + False, + _module_class, + ) return traced +_trace_module_map: Optional[Dict[Any, Any]] = None +def trace_module( + mod, + inputs, + optimize=None, + check_trace=True, + check_inputs=None, + check_tolerance=1e-5, + strict=True, + _force_outplace=False, + _module_class=None, + _compilation_unit=_python_cu, +): + + if not _enabled: + return mod + if optimize is not None: + warnings.warn( + "MSG" + ) var_lookup_fn = _create_interpreter_name_lookup_fn(0) if not isinstance(mod, torch.nn.Module): + raise AttributeError("MSG") if not isinstance(inputs, dict): + raise AttributeError("MSG") old_module_map = torch.jit._trace._trace_module_map + try: + trace_module_map: Dict[Any, Any] = {} def register_submods(mod, prefix): + for name, child in mod.named_children(): + submod_qualname = prefix + "MSG" + name + trace_module_map[child] = submod_qualname + register_submods(child, submod_qualname) trace_module_map["MSG"] = mod + torch.jit._trace._trace_module_map = trace_module_map + register_submods(mod, "MSG") module = make_module(mod, _module_class, _compilation_unit) for method_name, example_inputs in inputs.items(): + func = mod if method_name == "MSG" else getattr(mod, method_name) + example_inputs = make_tuple(example_inputs) + module._c._create_method_from_trace( + method_name, + func, + example_inputs, + var_lookup_fn, + strict, + _force_outplace, + ) + check_trace_method = module._c._get_method(method_name) if check_trace: + if check_inputs is not None: + _check_trace( + check_inputs, + func, + check_trace_method, + check_tolerance, + strict, + _force_outplace, + True, + _module_class, + ) + else: + _check_trace( + [inputs], + func, + check_trace_method, + check_tolerance, + strict, + _force_outplace, + True, + _module_class, + ) + finally: + torch.jit._trace._trace_module_map = old_module_map return module +def is_tracing(): + + return torch._C._is_tracing() +class TracedModule(ScriptModule): + _disable_script_meta = True def __init__(self, orig, id_set=None, _compilation_unit=None): + super(TracedModule, self).__init__() + assert isinstance(orig, torch.nn.Module) id_set = set() class QualnameWrapper(torch.nn.Module): + pass QualnameWrapper._jit_override_qualname = torch._jit_internal._qualified_name( + type(orig) + ) tmp_module = QualnameWrapper() def check_unique(param): + if param in id_set: + raise ValueError( + "MSG" + ) + id_set.add(param) tmp_module.training = orig.training for name, param in orig._parameters.items(): + if param is not None: + tmp_module._parameters[name] = param + check_unique(param) + for name, buf in orig._buffers.items(): + if buf is not None: + tmp_module._buffers[name] = buf + check_unique(buf) + for name, val in orig.__dict__.items(): + if ( + torch._C._jit_is_script_object(val) + and name not in orig._parameters + and name not in orig._buffers + ): + setattr(tmp_module, name, val) if orig._backward_hooks: + raise ValueError( + "MSG" + + str(orig) + ) for name, submodule in orig._modules.items(): + tmp_module._modules[name] = make_module( + submodule, TracedModule, _compilation_unit=None + ) script_module = torch.jit._recursive.create_script_module( + tmp_module, lambda module: (), share_types=False + ) self.__dict__["MSG"] = type(orig).__name__ + self.__dict__["MSG"] = script_module + for name in ("MSG", "MSG", "MSG"): + delattr(self, name) def forward(self, *args, **kwargs): + raise RuntimeError("MSG") def __getattr__(self, attr): + if "MSG" not in self.__dict__: + return super(TracedModule, self).__getattr__(attr) + return getattr(self._actual_script_module, attr) def __setattr__(self, attr, value): + if "MSG" not in self.__dict__: + return super(TracedModule, self).__setattr__(attr, value) + setattr(self._actual_script_module, attr, value) def _get_name(self): + return self._name def extra_repr(self): + return "MSG".format(self._name) +class TopLevelTracedModule(TracedModule): + forward = _CachedForward() def _reconstruct(self, cpp_module): + self.__dict__["MSG"]._reconstruct(cpp_module) +def _script_if_tracing(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if not is_tracing(): + return fn(*args, **kwargs) compiled_fn = script(wrapper.__original_fn) + return compiled_fn(*args, **kwargs) wrapper.__original_fn = fn + wrapper.__script_if_tracing_wrapper = True return wrapper +def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False, + return_inputs=False, _return_inputs_states=False): + + if kwargs is None: + kwargs = {} + if not isinstance(args, tuple): + args = (args,) + outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs) + return outs +import torch._C from torch.utils import set_module +from torch._jit_internal import ( + Final, + Future, + _overload, + _overload_method, + ignore, + is_scripting, + export, + unused, +) +from torch.jit._script import ( + script, + Attribute, + ScriptModule, + script_method, + RecursiveScriptModule, + ScriptWarning, + interface, + CompilationUnit, + ScriptFunction, + _unwrap_optional, +) +from torch.jit._trace import ( + trace, + trace_module, + TracedModule, + TracerWarning, + TracingCheckError, + is_tracing, + ONNXTracedModule, + TopLevelTracedModule, + _unique_state_dict, + _flatten, + _script_if_tracing, + _get_trace_graph, +) +from torch.jit._async import fork, wait +from torch.jit._serialization import save, load +from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph from torch.jit._freeze import freeze +_fork = fork +_wait = wait +def export_opnames(m): + r + return torch._C._export_opnames(m._c) Error = torch._C.JITException +set_module(Error, "MSG") Error.__name__ = "MSG" +Error.__qualname__ = "MSG" +def annotate(the_type, the_value): + + return the_value +if not torch._C._jit_init(): + raise RuntimeError("MSG") +import torch from torch.jit._serialization import validate_map_location import pathlib +import os def _load_for_lite_interpreter(f, map_location=None): + r + if isinstance(f, str): + if not os.path.exists(f): + raise ValueError("MSG".format(f)) + if os.path.isdir(f): + raise ValueError("MSG".format(f)) map_location = validate_map_location(map_location) if isinstance(f, str) or isinstance(f, pathlib.Path): + cpp_module = torch._C._load_for_lite_interpreter(f, map_location) + else: + cpp_module = torch._C._load_for_lite_interpreter_from_buffer(f.read(), map_location) return LiteScriptModule(cpp_module) +class LiteScriptModule(object): + def __init__(self, cpp_module): + self._c = cpp_module + super(LiteScriptModule, self).__init__() def __call__(self, *input): + return self._c.forward(input) def find_method(self, method_name): + return self._c.find_method(method_name) def forward(self, *input): + return self._c.forward(input) def run_method(self, method_name, *input): + return self._c.run_method(method_name, input) +import sys import torch +from torch._C import _add_docstr, _linalg Tensor = torch.Tensor +det = _add_docstr(_linalg.linalg_det, r) norm = _add_docstr(_linalg.linalg_norm, r) +import multiprocessing +import multiprocessing.pool +import multiprocessing.util as util from .queue import SimpleQueue +def clean_worker(*args, **kwargs): + import gc + multiprocessing.pool.worker(*args, **kwargs) + + + + gc.collect() +class Pool(multiprocessing.pool.Pool): + def _setup_queues(self): + self._inqueue = SimpleQueue() + self._outqueue = SimpleQueue() + self._quick_put = self._inqueue._writer.send + self._quick_get = self._outqueue._reader.recv def _repopulate_pool(self): + for i in range(self._processes - len(self._pool)): + args = (self._inqueue, self._outqueue, + self._initializer, + self._initargs, self._maxtasksperchild) + if hasattr(self, 'MSG'): + args += (self._wrap_exception,) + w = self.Process(target=clean_worker, args=args) + self._pool.append(w) + w.name = w.name.replace('MSG', 'MSG') + w.daemon = True + w.start() + util.debug('MSG') +import io +import multiprocessing +import multiprocessing.queues +from multiprocessing.reduction import ForkingPickler +import pickle +class ConnectionWrapper(object): + def __init__(self, conn): + self.conn = conn def send(self, obj): + buf = io.BytesIO() + ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj) + self.send_bytes(buf.getvalue()) def recv(self): + buf = self.recv_bytes() + return pickle.loads(buf) def __getattr__(self, name): + if 'MSG' in self.__dict__: + return getattr(self.conn, name) + raise AttributeError("MSG".format( + type(self).__name__, 'MSG')) +class Queue(multiprocessing.queues.Queue): def __init__(self, *args, **kwargs): + super(Queue, self).__init__(*args, **kwargs) + self._reader = ConnectionWrapper(self._reader) + self._writer = ConnectionWrapper(self._writer) + self._send = self._writer.send + self._recv = self._reader.recv +class SimpleQueue(multiprocessing.queues.SimpleQueue): def _make_methods(self): + if not isinstance(self._reader, ConnectionWrapper): + self._reader = ConnectionWrapper(self._reader) + self._writer = ConnectionWrapper(self._writer) + super(SimpleQueue, self)._make_methods() +import torch +import torch.utils.hooks +from torch._namedtensor_internals import check_serializing_named_tensor +import os +import threading +import multiprocessing +from multiprocessing.util import register_after_fork +from multiprocessing.reduction import ForkingPickler +try: + + + + + + import multiprocessing.resource_sharer +except ImportError: + pass +class StorageWeakRef(object): + r def __init__(self, storage): + self.cdata = storage._weak_ref() + self._free_weak_ref = torch.Storage._free_weak_ref def expired(self): + return torch.Storage._expired(self.cdata) def __del__(self): + self._free_weak_ref(self.cdata) +class SharedCache(dict): + def __init__(self): + self.limit = 128 + self._after_fork() + register_after_fork(self, SharedCache._after_fork) def _after_fork(self): + self.lock = threading.Lock() def __setitem__(self, key, storage_ref): + dict.__setitem__(self, key, storage_ref) + if len(self) > self.limit: + self.free_dead_references() def free_dead_references(self): + with self.lock: + live = 0 + for key, storage_ref in list(self.items()): + if storage_ref.expired(): + del self[key] + else: + live += 1 + self.limit = max(128, live * 2) shared_cache = SharedCache() +def rebuild_event(device, handle): + return torch.cuda.Event.from_ipc_handle(device, handle) +def reduce_event(event): + handle = event.ipc_handle() + return (rebuild_event, (event.device, handle)) +def rebuild_tensor(cls, storage, metadata): + storage_offset, size, stride, requires_grad = metadata + t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) + if cls == torch.nn.parameter.Parameter: + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + return t +def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset, + storage_cls, storage_device, storage_handle, storage_size_bytes, storage_offset_bytes, + requires_grad, ref_counter_handle, ref_counter_offset, event_handle, event_sync_required): + + if storage_handle is None or storage_size_bytes == 0: + storage = storage_cls(0) + else: + storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes)) + if storage is None: + torch.cuda._lazy_init() + storage = storage_cls._new_shared_cuda( + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required) + shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage) + else: + storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset) t = torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride) + if tensor_cls == torch.nn.parameter.Parameter: + t = torch.nn.parameter.Parameter(t) + t.requires_grad = requires_grad + return t +def reduce_tensor(tensor): + storage = tensor.storage() if tensor.requires_grad and not tensor.is_leaf: + raise RuntimeError("MSG" + "MSG" + "MSG" + "MSG") check_serializing_named_tensor(tensor) + torch.utils.hooks.warn_if_has_hooks(tensor) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + if storage.is_cuda: + (device, + handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required) = storage._share_cuda_() + tensor_offset = tensor.storage_offset() + shared_cache[handle] = StorageWeakRef(storage) + return (rebuild_cuda_tensor, + (type(tensor), + tensor.size(), + tensor.stride(), + tensor_offset, + type(storage), + device, + handle, + storage_size_bytes, + storage_offset_bytes, + tensor.requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required)) + metadata = (tensor.storage_offset(), tensor.size(), tensor.stride(), tensor.requires_grad) + return (rebuild_tensor, (type(tensor), storage, metadata)) +def fd_id(fd): + + + + stat = os.fstat(fd) + return (stat.st_ino, stat.st_dev) +def storage_from_cache(cls, key): + storage_ref = shared_cache.get(key) + if storage_ref is None: + return None + return cls._new_with_weak_ptr(storage_ref.cdata) +def rebuild_storage_fd(cls, df, size): + fd = df.detach() + try: + storage = storage_from_cache(cls, fd_id(fd)) + if storage is not None: + return storage + storage = cls._new_shared_fd(fd, size) + shared_cache[fd_id(fd)] = StorageWeakRef(storage) + return storage + finally: + os.close(fd) +def rebuild_storage_filename(cls, manager, handle, size): + storage = storage_from_cache(cls, handle) + if storage is not None: + return storage._shared_decref() + storage = cls._new_shared_filename(manager, handle, size) + shared_cache[handle] = StorageWeakRef(storage) + return storage._shared_decref() +def rebuild_storage_empty(cls): + return cls() +def reduce_storage(storage): + from . import get_sharing_strategy + if storage.is_cuda: + raise RuntimeError("MSG") + elif get_sharing_strategy() == 'MSG': + metadata = storage._share_filename_() + cache_key = metadata[1] + rebuild = rebuild_storage_filename + storage._shared_incref() + elif storage.size() == 0: + return (rebuild_storage_empty, (type(storage),)) + else: + fd, size = storage._share_fd_() + df = multiprocessing.reduction.DupFd(fd) + cache_key = fd_id(fd) + metadata = (df, size) + rebuild = rebuild_storage_fd shared_cache[cache_key] = StorageWeakRef(storage) + return (rebuild, (type(storage),) + metadata) +def init_reductions(): + ForkingPickler.register(torch.cuda.Event, reduce_event) for t in torch._storage_classes: + ForkingPickler.register(t, reduce_storage) for t in torch._tensor_classes: + ForkingPickler.register(t, reduce_tensor) + ForkingPickler.register(torch.Tensor, reduce_tensor) + ForkingPickler.register(torch.nn.parameter.Parameter, reduce_tensor) import multiprocessing +import multiprocessing.connection +import signal +import sys +import warnings from . import _prctl_pr_set_pdeathsig +def _wrap(fn, i, args, error_queue): + + + + + _prctl_pr_set_pdeathsig(signal.SIGINT) try: + fn(i, *args) + except KeyboardInterrupt: + pass + except Exception: + import traceback + error_queue.put(traceback.format_exc()) + sys.exit(1) _supports_context = sys.version_info >= (3, 4) +def _python_version_check(): + if not _supports_context: + raise RuntimeError("MSG" + "MSG" + "MSG" + "MSG" + "MSG" + "MSG" + "MSG") +class ProcessContext: + def __init__(self, processes, error_queues): + _python_version_check() + self.error_queues = error_queues + self.processes = processes + self.sentinels = { + process.sentinel: index + for index, process in enumerate(processes) + } def pids(self): + return [int(process.pid) for process in self.processes] def join(self, timeout=None): + r + if len(self.sentinels) == 0: + return True ready = multiprocessing.connection.wait( + self.sentinels.keys(), + timeout=timeout, + ) error_index = None + for sentinel in ready: + index = self.sentinels.pop(sentinel) + process = self.processes[index] + process.join() + if process.exitcode != 0: + error_index = index + break if error_index is None: + return len(self.sentinels) == 0 for process in self.processes: + if process.is_alive(): + process.terminate() + process.join() if self.error_queues[error_index].empty(): + exitcode = self.processes[error_index].exitcode + if exitcode < 0: + name = signal.Signals(-exitcode).name + raise Exception( + "MSG" % + (error_index, name) + ) + else: + raise Exception( + "MSG" % + (error_index, exitcode) + ) original_trace = self.error_queues[error_index].get() + msg = "MSG" % error_index + msg += original_trace + raise Exception(msg) +class SpawnContext(ProcessContext): + def __init__(self, processes, error_queues): + warnings.warn('MSG') + super(SpawnContext, self).__init__(self, processes, error_queues) + pass +def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='MSG'): + _python_version_check() + mp = multiprocessing.get_context(start_method) + error_queues = [] + processes = [] + for i in range(nprocs): + error_queue = mp.SimpleQueue() + process = mp.Process( + target=_wrap, + args=(fn, i, args, error_queue), + daemon=daemon, + ) + process.start() + error_queues.append(error_queue) + processes.append(process) context = ProcessContext(processes, error_queues) + if not join: + return context + while not context.join(): + pass +def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='MSG'): + r + if start_method != 'MSG': + msg = ('MSG' + 'MSG' + 'MSG' % start_method) + warnings.warn(msg) + return start_processes(fn, args, nprocs, join, daemon, start_method='MSG') +import sys __all__ = ['MSG'] if sys.platform == 'MSG' or sys.version_info < (3, 7): + import multiprocessing.util as _util def _register(func): + def wrapper(arg): + func() + _util.register_after_fork(_register, wrapper) +else: + import os def _register(func): + os.register_at_fork(after_in_child=func) def register_after_fork(func): + + _register(func) import torch +import sys +from .reductions import init_reductions +import multiprocessing __all__ = ['MSG', 'MSG', + 'MSG'] +from multiprocessing import * +__all__ += multiprocessing.__all__ +torch._C._multiprocessing_init() +if sys.version_info < (3, 3): + + from .queue import Queue, SimpleQueue + from .pool import Pool from .spawn import spawn, SpawnContext, _supports_context, start_processes, ProcessContext +if sys.platform == 'MSG' or sys.platform == 'MSG': + _sharing_strategy = 'MSG' + _all_sharing_strategies = {'MSG'} +else: + _sharing_strategy = 'MSG' + _all_sharing_strategies = {'MSG', 'MSG'} +def set_sharing_strategy(new_strategy): + + global _sharing_strategy + assert new_strategy in _all_sharing_strategies + _sharing_strategy = new_strategy +def get_sharing_strategy(): + + return _sharing_strategy +def get_all_sharing_strategies(): + + return _all_sharing_strategies +init_reductions() +from typing import TypeVar, Union, Tuple +from .. import Tensor +T = TypeVar('MSG') +_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]] +_scalar_or_tuple_1_t = Union[T, Tuple[T]] +_scalar_or_tuple_2_t = Union[T, Tuple[T, T]] +_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]] +_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]] +_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]] +_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]] +_size_any_t = _scalar_or_tuple_any_t[int] +_size_1_t = _scalar_or_tuple_1_t[int] +_size_2_t = _scalar_or_tuple_2_t[int] +_size_3_t = _scalar_or_tuple_3_t[int] +_size_4_t = _scalar_or_tuple_4_t[int] +_size_5_t = _scalar_or_tuple_5_t[int] +_size_6_t = _scalar_or_tuple_6_t[int] +_ratio_2_t = _scalar_or_tuple_2_t[float] +_ratio_3_t = _scalar_or_tuple_3_t[float] +_ratio_any_t = _scalar_or_tuple_any_t[float] _tensor_list_t = _scalar_or_tuple_any_t[Tensor] +_maybe_indices_t = _scalar_or_tuple_2_t[Tensor] +from torch import nn +class OrderedDictWrapper(object): + def __init__(self, cpp_module, attr): + self.cpp_module = cpp_module + self.attr = attr @property + def cpp_dict(self): + return getattr(self.cpp_module, self.attr) + def items(self): + return self.cpp_dict.items() def keys(self): + return self.cpp_dict.keys() def values(self): + return self.cpp_dict.values() def __iter__(self): + return self.cpp_dict.__iter__() def __len__(self): + return self.cpp_dict.__len__() def __contains__(self, key): + return self.cpp_dict.__contains__(key) def __getitem__(self, key): + return self.cpp_dict.__getitem__(key) +class ModuleWrapper(nn.Module): + def __init__(self, cpp_module): + self.cpp_module = cpp_module + super(ModuleWrapper, self).__init__() + self._parameters = OrderedDictWrapper(cpp_module, "MSG") + self._buffers = OrderedDictWrapper(cpp_module, "MSG") + self._modules = OrderedDictWrapper(cpp_module, "MSG") + for attr in dir(cpp_module): + if not attr.startswith("MSG"): + setattr(self, attr, getattr(self.cpp_module, attr)) def _apply(self, fn): + for param in self.parameters(): + param.data = fn(param.data) + if param._grad is not None: + param._grad.data = fn(param._grad.data) for buf in self.buffers(): + buf.data = fn(buf.data) return self @property + def training(self): + return self.cpp_module.training @training.setter + def training(self, mode): + self.cpp_module.train(mode) def __repr__(self): + return self.cpp_module.__repr__() +r +import warnings +import math import torch +from torch._C import _infer_size, _add_docstr +from . import _reduction as _Reduction +from .modules import utils +from .modules.utils import _single, _pair, _triple, _list_with_default +from . import grad +from torch import _VF +from .._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple +from ..overrides import has_torch_function, handle_torch_function +Tensor = torch.Tensor conv1d = _add_docstr(torch.conv1d, r) conv2d = _add_docstr(torch.conv2d, r) conv3d = _add_docstr(torch.conv3d, r) conv_transpose1d = _add_docstr(torch.conv_transpose1d, r) conv_transpose2d = _add_docstr(torch.conv_transpose2d, r) conv_transpose3d = _add_docstr(torch.conv_transpose3d, r) conv_tbc = _add_docstr(torch.conv_tbc, r) avg_pool1d = _add_docstr(torch.avg_pool1d, r) +avg_pool2d = _add_docstr(torch._C._nn.avg_pool2d, r) avg_pool3d = _add_docstr(torch._C._nn.avg_pool3d, r) +def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None, + output_ratio=None, return_indices=False, + _random_samples=None): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + fractional_max_pool2d_with_indices, (input,), input, kernel_size, + output_size=output_size, output_ratio=output_ratio, + return_indices=return_indices, _random_samples=_random_samples) + if output_size is None and output_ratio is None: + raise ValueError("MSG" + "MSG") + if output_size is None: + assert output_ratio is not None + _output_ratio = _pair(output_ratio) + output_size = [int(input.size(2) * _output_ratio[0]), + int(input.size(3) * _output_ratio[1])] if _random_samples is None: + _random_samples = torch.rand(input.size(0), input.size(1), 2, dtype=input.dtype, device=input.device) + return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples) +def _fractional_max_pool2d(input, kernel_size, output_size=None, + output_ratio=None, return_indices=False, + _random_samples=None): + + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + fractional_max_pool2d, (input,), input, kernel_size, + output_size=output_size, output_ratio=output_ratio, + return_indices=return_indices, _random_samples=_random_samples) + return fractional_max_pool2d_with_indices(input, kernel_size, output_size, + output_ratio, return_indices, + _random_samples)[0] fractional_max_pool2d = boolean_dispatch( + arg_name='MSG', + arg_index=4, + default=False, + if_true=fractional_max_pool2d_with_indices, + if_false=_fractional_max_pool2d, + module_name=__name__, + func_name='MSG') +def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, + output_ratio=None, return_indices=False, + _random_samples=None): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + fractional_max_pool3d_with_indices, (input,), input, kernel_size, + output_size=output_size, output_ratio=output_ratio, + return_indices=return_indices, _random_samples=_random_samples) + if output_size is None and output_ratio is None: + raise ValueError("MSG" + "MSG") + if output_size is None: + assert output_ratio is not None + _output_ratio = _triple(output_ratio) + output_size = [int(input.size(2) * _output_ratio[0]), + int(input.size(3) * _output_ratio[1]), + int(input.size(4) * _output_ratio[2])] if _random_samples is None: + _random_samples = torch.rand(input.size(0), input.size(1), 3, dtype=input.dtype, device=input.device) + return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples) +def _fractional_max_pool3d(input, kernel_size, output_size=None, + output_ratio=None, return_indices=False, + _random_samples=None): + + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + fractional_max_pool3d, (input,), input, kernel_size, + output_size=output_size, output_ratio=output_ratio, + return_indices=return_indices, _random_samples=_random_samples) + return fractional_max_pool3d_with_indices(input, kernel_size, output_size, + output_ratio, return_indices, + _random_samples)[0] fractional_max_pool3d = boolean_dispatch( + arg_name='MSG', + arg_index=4, + default=False, + if_true=fractional_max_pool3d_with_indices, + if_false=_fractional_max_pool3d, + module_name=__name__, + func_name='MSG') +def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, + dilation=1, ceil_mode=False, return_indices=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + max_pool1d_with_indices, (input,), input, kernel_size, + stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, + return_indices=return_indices) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool1d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode) +def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, + ceil_mode=False, return_indices=False): + + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + max_pool1d, (input,), input, kernel_size, + stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, + return_indices=return_indices) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool1d( + input, kernel_size, stride, padding, dilation, ceil_mode) max_pool1d = boolean_dispatch( + arg_name='MSG', + arg_index=6, + default=False, + if_true=max_pool1d_with_indices, + if_false=_max_pool1d, + module_name=__name__, + func_name='MSG') +def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1, + ceil_mode=False, return_indices=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + max_pool2d_with_indices, (input,), input, kernel_size, + stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, + return_indices=return_indices) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) +def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, + ceil_mode=False, return_indices=False): + + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + max_pool2d, (input,), input, kernel_size, + stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, + return_indices=return_indices) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool2d( + input, kernel_size, stride, padding, dilation, ceil_mode) max_pool2d = boolean_dispatch( + arg_name='MSG', + arg_index=6, + default=False, + if_true=max_pool2d_with_indices, + if_false=_max_pool2d, + module_name=__name__, + func_name='MSG') +def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0, + dilation=1, ceil_mode=False, return_indices=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + max_pool3d_with_indices, (input,), input, kernel_size, + stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, + return_indices=return_indices) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch._C._nn.max_pool3d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode) +def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, + ceil_mode=False, return_indices=False): + + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + max_pool3d, (input,), input, kernel_size, stride=stride, padding=padding, + dilation=dilation, ceil_mode=ceil_mode, return_indices=return_indices) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool3d( + input, kernel_size, stride, padding, dilation, ceil_mode) max_pool3d = boolean_dispatch( + arg_name='MSG', + arg_index=6, + default=False, + if_true=max_pool3d_with_indices, + if_false=_max_pool3d, + module_name=__name__, + func_name='MSG') +def _unpool_output_size(input, kernel_size, stride, padding, output_size): + + input_size = input.size() + default_size = torch.jit.annotate(List[int], []) + for d in range(len(kernel_size)): + default_size.append((input_size[d + 2] - 1) * stride[d] + + kernel_size[d] - 2 * padding[d]) + if output_size is None: + ret = default_size + else: + if len(output_size) == len(kernel_size) + 2: + output_size = output_size[2:] + if len(output_size) != len(kernel_size): + raise ValueError("MSG" + "MSG" + .format(len(kernel_size), len(kernel_size) + 2, + len(output_size))) + for d in range(len(kernel_size)): + min_size = default_size[d] - stride[d] + max_size = default_size[d] + stride[d] + if not (min_size < output_size[d] < max_size): + raise ValueError( + 'MSG' + .format(output_size, d, min_size, max_size)) ret = output_size + return ret +def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, + output_size=None): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + max_unpool1d, (input,), input, indices, kernel_size, + stride=stride, padding=padding, output_size=output_size) + kernel_size = _single(kernel_size) + if stride is not None: + _stride = _single(stride) + else: + _stride = kernel_size + padding = _single(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, + output_size) + if isinstance(output_size, list): + output_size = output_size + [1] + else: + output_size = output_size + (1,) + return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), + output_size).squeeze(3) +def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, + output_size=None): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + max_unpool2d, (input,), input, indices, kernel_size, + stride=stride, padding=padding, output_size=output_size) + kernel_size = _pair(kernel_size) + if stride is not None: + _stride = _pair(stride) + else: + _stride = kernel_size + padding = _pair(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, + output_size) + return torch._C._nn.max_unpool2d(input, indices, output_size) +def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, + output_size=None): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + max_unpool3d, (input,), input, indices, kernel_size, + stride=stride, padding=padding, output_size=output_size) + kernel_size = _triple(kernel_size) + if stride is not None: + _stride = _triple(stride) + else: + _stride = kernel_size + padding = _triple(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, + output_size) + return torch._C._nn.max_unpool3d( + input, indices, output_size, _stride, padding) +def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, + ceil_mode=ceil_mode) + kw, kh = utils._pair(kernel_size) + if stride is not None: + out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool2d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type) +def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, + ceil_mode=ceil_mode) + if stride is not None: + out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool1d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type) +def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + adaptive_max_pool1d_with_indices, (input,), input, output_size, + return_indices=return_indices) + return torch.adaptive_max_pool1d(input, output_size) +def _adaptive_max_pool1d(input, output_size, return_indices=False): + + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + adaptive_max_pool1d, (input,), input, output_size, + return_indices=return_indices) + return adaptive_max_pool1d_with_indices(input, output_size)[0] adaptive_max_pool1d = boolean_dispatch( + arg_name='MSG', + arg_index=2, + default=False, + if_true=adaptive_max_pool1d_with_indices, + if_false=_adaptive_max_pool1d, + module_name=__name__, + func_name='MSG') +def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + adaptive_max_pool2d_with_indices, (input,), input, output_size, + return_indices=return_indices) + output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_max_pool2d(input, output_size) +def _adaptive_max_pool2d(input, output_size, return_indices=False): + + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + adaptive_max_pool2d, (input,), input, output_size, + return_indices=return_indices) + return adaptive_max_pool2d_with_indices(input, output_size)[0] adaptive_max_pool2d = boolean_dispatch( + arg_name='MSG', + arg_index=2, + default=False, + if_true=adaptive_max_pool2d_with_indices, + if_false=_adaptive_max_pool2d, + module_name=__name__, + func_name='MSG') +def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + adaptive_max_pool3d_with_indices, (input,), input, output_size, + return_indices=return_indices) + output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_max_pool3d(input, output_size) +def _adaptive_max_pool3d(input, output_size, return_indices=False): + + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + adaptive_max_pool3d, (input,), input, output_size, + return_indices=return_indices) + return adaptive_max_pool3d_with_indices(input, output_size)[0] adaptive_max_pool3d = boolean_dispatch( + arg_name='MSG', + arg_index=2, + default=False, + if_true=adaptive_max_pool3d_with_indices, + if_false=_adaptive_max_pool3d, + module_name=__name__, + func_name='MSG') +adaptive_avg_pool1d = _add_docstr(torch.adaptive_avg_pool1d, r) +def adaptive_avg_pool2d(input, output_size): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + adaptive_avg_pool2d, (input,), input, output_size) + _output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_avg_pool2d(input, _output_size) +def adaptive_avg_pool3d(input, output_size): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + adaptive_avg_pool3d, (input,), input, output_size) + _output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_avg_pool3d(input, _output_size) def dropout(input, p=0.5, training=True, inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + dropout, (input,), input, p=p, training=training, inplace=inplace) + if p < 0. or p > 1.: + raise ValueError("MSG" + "MSG".format(p)) + return (_VF.dropout_(input, p, training) + if inplace + else _VF.dropout(input, p, training)) +def alpha_dropout(input, p=0.5, training=False, inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) + if p < 0. or p > 1.: + raise ValueError("MSG" + "MSG".format(p)) + return (_VF.alpha_dropout_(input, p, training) + if inplace + else _VF.alpha_dropout(input, p, training)) +def dropout2d(input, p=0.5, training=True, inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + dropout2d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0. or p > 1.: + raise ValueError("MSG" + "MSG".format(p)) + return (_VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training)) +def dropout3d(input, p=0.5, training=True, inplace=False): + + r + + + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + dropout3d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0. or p > 1.: + raise ValueError("MSG" + "MSG".format(p)) + return (_VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training)) +def feature_alpha_dropout(input, p=0.5, training=False, inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + feature_alpha_dropout, (input,), input, p=p, training=training, + inplace=inplace) + if p < 0. or p > 1.: + raise ValueError("MSG" + "MSG".format(p)) + return (_VF.feature_alpha_dropout_(input, p, training) + if inplace + else _VF.feature_alpha_dropout(input, p, training)) +def _threshold(input, threshold, value, inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + _threshold, (input,), input, threshold, value, inplace=inplace) + if inplace: + result = _VF.threshold_(input, threshold, value) + else: + result = _VF.threshold(input, threshold, value) + return result +threshold = _threshold threshold_ = _add_docstr(_VF.threshold_, r) +def relu(input: Tensor, inplace: bool = False) -> Tensor: + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(relu, (input,), input, inplace=inplace) + if inplace: + result = torch.relu_(input) + else: + result = torch.relu(input) + return result +relu_ = _add_docstr(torch.relu_, r) +def glu(input: Tensor, dim: int = -1) -> Tensor: + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(glu, (input,), input, dim=dim) + if input.dim() == 0: + raise RuntimeError("MSG") + return torch._C._nn.glu(input, dim) +def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor: + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + hardtanh, (input,), input, min_val=min_val, max_val=max_val, + inplace=inplace) + if inplace: + result = torch._C._nn.hardtanh_(input, min_val, max_val) + else: + result = torch._C._nn.hardtanh(input, min_val, max_val) + return result +hardtanh_ = _add_docstr(torch._C._nn.hardtanh_, r) +def relu6(input, inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(relu6, (input,), input, inplace=inplace) + return hardtanh(input, 0., 6., inplace) +def elu(input, alpha=1., inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(elu, (input,), input, alpha=alpha, + inplace=inplace) + if inplace: + result = torch._C._nn.elu_(input, alpha) + else: + result = torch._C._nn.elu(input, alpha) + return result +elu_ = _add_docstr(torch._C._nn.elu_, r) +def selu(input, inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(selu, (input,), input, inplace=inplace) + if inplace: + result = torch.selu_(input) + else: + result = torch.selu(input) + return result +selu_ = _add_docstr(torch.selu_, r) +def celu(input, alpha=1., inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(celu, (input,), input, alpha=alpha, + inplace=inplace) + if inplace: + result = torch.celu_(input, alpha) + else: + result = torch.celu(input, alpha) + return result celu_ = _add_docstr(torch.celu_, r) +def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False) -> Tensor: + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + leaky_relu, (input,), input, negative_slope=negative_slope, + inplace=inplace) + if inplace: + result = torch._C._nn.leaky_relu_(input, negative_slope) + else: + result = torch._C._nn.leaky_relu(input, negative_slope) + return result +leaky_relu_ = _add_docstr(torch._C._nn.leaky_relu_, r) +def prelu(input, weight): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(prelu, (input,), input, weight) + return torch.prelu(input, weight) +def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + rrelu, (input,), input, lower=lower, upper=upper, + training=training, inplace=inplace) + if inplace: + result = torch.rrelu_(input, lower, upper, training) + else: + result = torch.rrelu(input, lower, upper, training) + return result +rrelu_ = _add_docstr(torch.rrelu_, r) logsigmoid = _add_docstr(torch._C._nn.log_sigmoid, r) def gelu(input): + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(gelu, (input,), input) + return torch._C._nn.gelu(input) +def hardshrink(input, lambd=0.5): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(hardshrink, (input,), input, lambd=lambd) + return torch.hardshrink(input, lambd) +def tanhshrink(input): + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(tanhshrink, (input,), input) + return input - input.tanh() +def softsign(input): + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(softsign, (input,), input) + return input / (input.abs() + 1) +softplus = _add_docstr(torch._C._nn.softplus, r) +def _get_softmax_dim(name, ndim, stacklevel): + + warnings.warn("MSG" + "MSG".format(name), stacklevel=stacklevel) + if ndim == 0 or ndim == 1 or ndim == 3: + ret = 0 + else: + ret = 1 + return ret +def softmin(input, dim=None, _stacklevel=3, dtype=None): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if dim is None: + dim = _get_softmax_dim('MSG', input.dim(), _stacklevel) + if dtype is None: + ret = (-input).softmax(dim) + else: + ret = (-input).softmax(dim, dtype=dtype) + return ret +def softmax(input, dim=None, _stacklevel=3, dtype=None): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if dim is None: + dim = _get_softmax_dim('MSG', input.dim(), _stacklevel) + if dtype is None: + ret = input.softmax(dim) + else: + ret = input.softmax(dim, dtype=dtype) + return ret +def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): + + r + if not torch.jit.is_scripting(): + if type(logits) is not Tensor and has_torch_function((logits,)): + return handle_torch_function( + gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) + if eps != 1e-10: + warnings.warn("MSG") gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + gumbels = (logits + gumbels) / tau + y_soft = gumbels.softmax(dim) if hard: + index = y_soft.max(dim, keepdim=True)[1] + y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) + ret = y_hard - y_soft.detach() + y_soft + else: + ret = y_soft + return ret +def log_softmax(input, dim=None, _stacklevel=3, dtype=None): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if dim is None: + dim = _get_softmax_dim('MSG', input.dim(), _stacklevel) + if dtype is None: + ret = input.log_softmax(dim) + else: + ret = input.log_softmax(dim, dtype=dtype) + return ret +softshrink = _add_docstr(torch._C._nn.softshrink, r) +def tanh(input): + r + warnings.warn("MSG") + return input.tanh() +def sigmoid(input): + r + warnings.warn("MSG") + return input.sigmoid() +def hardsigmoid(input, inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.hardsigmoid_(input) + return torch._C._nn.hardsigmoid(input) +def linear(input, weight, bias=None): + + r + tens_ops = (input, weight) + if not torch.jit.is_scripting(): + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function(linear, tens_ops, input, weight, bias=bias) + if input.dim() == 2 and bias is not None: + ret = torch.addmm(bias, input, weight.t()) + else: + output = input.matmul(weight.t()) + if bias is not None: + output += bias + ret = output + return ret +def bilinear(input1, input2, weight, bias=None): + + r + return torch.bilinear(input1, input2, weight, bias) def silu(input, inplace=False): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(silu, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.silu_(input) + return torch._C._nn.silu(input) def hardswish(input: Tensor, inplace: bool = False) -> Tensor: + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function(hardswish, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.hardswish_(input) + return torch._C._nn.hardswish(input) +def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type): + + with torch.no_grad(): + torch.embedding_renorm_(weight, input, max_norm, norm_type) +def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., + scale_grad_by_freq=False, sparse=False): + + r + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < weight.size(0), 'MSG' + elif padding_idx < 0: + assert padding_idx >= -weight.size(0), 'MSG' + padding_idx = weight.size(0) + padding_idx + else: + padding_idx = -1 + if max_norm is not None: + input = input.contiguous() + _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) + return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) +def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, + scale_grad_by_freq=False, mode='MSG', sparse=False, + per_sample_weights=None, include_last_offset=False): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, weight) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + embedding_bag, tens_ops, input, weight, offsets=offsets, max_norm=max_norm, + norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, mode=mode, + sparse=sparse, per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset) + + + + if weight.dtype == torch.long and input.is_floating_point(): + warnings.warn("MSG" + "MSG" + "MSG") + weight, input = input, weight if per_sample_weights is not None and input.size() != per_sample_weights.size(): + raise ValueError("MSG" + "MSG" + .format(per_sample_weights.shape, input.shape)) if input.dim() == 2: + if offsets is not None: + type_str = "MSG" + if not torch.jit.is_scripting(): + type_str = str(type(offsets)) + raise ValueError("MSG" + "MSG" + "MSG" + "MSG".format(type_str)) + offsets = torch.arange(0, input.numel(), input.size(1), + dtype=torch.long, device=input.device) input = input.reshape(-1) + if per_sample_weights is not None: + per_sample_weights = per_sample_weights.reshape(-1) + elif input.dim() == 1: + if offsets is None: + raise ValueError("MSG") + if offsets.dim() != 1: + raise ValueError("MSG") + else: + raise ValueError("MSG" + "MSG".format(input.dim())) + if mode == 'MSG': + mode_enum = 0 + elif mode == 'MSG': + mode_enum = 1 + elif mode == 'MSG': + mode_enum = 2 if scale_grad_by_freq: + raise ValueError("MSG") if sparse: + raise ValueError("MSG") else: + raise ValueError("MSG") if max_norm is not None: + _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) if per_sample_weights is not None and mode != 'MSG': + raise NotImplementedError("MSG" + "MSG" + "MSG" + .format(mode)) ret, _, _, _ = torch.embedding_bag( + weight, + input, + offsets, + scale_grad_by_freq, + mode_enum, + sparse, + per_sample_weights, + include_last_offset) + return ret +def _verify_batch_size(size): + + + + + + + + + + + size_prods = size[0] + for i in range(len(size) - 2): + size_prods *= size[i + 2] + if size_prods == 1: + raise ValueError('MSG'.format(size)) +def batch_norm(input, running_mean, running_var, weight=None, bias=None, + training=False, momentum=0.1, eps=1e-5): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + batch_norm, (input,), input, running_mean, running_var, weight=weight, + bias=bias, training=training, momentum=momentum, eps=eps) + if training: + _verify_batch_size(input.size()) return torch.batch_norm( + input, weight, bias, running_mean, running_var, + training, momentum, eps, torch.backends.cudnn.enabled + ) +def instance_norm(input, running_mean=None, running_var=None, weight=None, + bias=None, use_input_stats=True, momentum=0.1, eps=1e-5): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + instance_norm, (input,), input, running_mean=running_mean, + running_var=running_var, weight=weight, bias=bias, + use_input_stats=use_input_stats, momentum=momentum, eps=eps) + _verify_batch_size(input.size()) + return torch.instance_norm( + input, weight, bias, running_mean, running_var, + use_input_stats, momentum, eps, torch.backends.cudnn.enabled + ) +def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps) + return torch.layer_norm(input, normalized_shape, weight, bias, eps, + torch.backends.cudnn.enabled) +def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps) + _verify_batch_size([ + input.size(0) * input.size(1) // num_groups, num_groups] + + list(input.size()[2:])) + return torch.group_norm(input, num_groups, weight, bias, eps, + torch.backends.cudnn.enabled) +def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) + dim = input.dim() + if dim < 3: + raise ValueError('MSG'.format(dim)) + div = input.mul(input).unsqueeze(1) + if dim == 3: + div = pad(div, (0, 0, size // 2, (size - 1) // 2)) + div = avg_pool2d(div, (size, 1), stride=1).squeeze(1) + else: + sizes = input.size() + div = div.view(sizes[0], 1, sizes[1], sizes[2], -1) + div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2)) + div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1) + div = div.view(sizes) + div = div.mul(alpha).add(k).pow(beta) + return input / div +def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, + reduction='MSG', zero_infinity=False): + + r + return torch.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), + zero_infinity) +def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, + reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + nll_loss, tens_ops, input, target, weight=weight, size_average=size_average, + ignore_index=ignore_index, reduce=reduce, reduction=reduction) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + dim = input.dim() + if dim < 2: + raise ValueError('MSG'.format(dim)) if input.size(0) != target.size(0): + raise ValueError('MSG' + .format(input.size(0), target.size(0))) + if dim == 2: + ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index) + elif dim == 4: + ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index) + else: + n = input.size(0) + c = input.size(1) + out_size = (n,) + input.size()[2:] + if target.size()[1:] != input.size()[2:]: + raise ValueError('MSG'.format( + out_size, target.size())) + input = input.contiguous() + target = target.contiguous() + if input.numel() > 0: + input = input.view(n, c, 1, -1) + else: + input = input.view(n, c, 0, 0) + if target.numel() > 0: + target = target.view(n, 1, -1) + else: + target = target.view(n, 0, 0) + reduction_enum = _Reduction.get_enum(reduction) + if reduction != 'MSG': + ret = torch._C._nn.nll_loss2d( + input, target, weight, reduction_enum, ignore_index) + else: + out = torch._C._nn.nll_loss2d( + input, target, weight, reduction_enum, ignore_index) + ret = out.view(out_size) + return ret +def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8, + reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + poisson_nll_loss, tens_ops, input, target, log_input=log_input, full=full, + size_average=size_average, eps=eps, reduce=reduce, reduction=reduction) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + if reduction != 'MSG' and reduction != 'MSG' and reduction != 'MSG': + ret = input + raise ValueError(reduction + "MSG") ret = torch.poisson_nll_loss(input, target, log_input, full, eps, _Reduction.get_enum(reduction)) + return ret +def kl_div(input, target, size_average=None, reduce=None, reduction='MSG', log_target=False): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + kl_div, tens_ops, input, target, size_average=size_average, + reduce=reduce, reduction=reduction, log_target=log_target) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + if reduction == 'MSG': + warnings.warn("MSG" + "MSG" + "MSG") if reduction == 'MSG': + reduction_enum = _Reduction.get_enum('MSG') + else: + reduction_enum = _Reduction.get_enum(reduction) reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target) if reduction == 'MSG' and input.dim() != 0: + reduced = reduced / input.size()[0] return reduced +def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, + reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + cross_entropy, tens_ops, input, target, weight=weight, + size_average=size_average, ignore_index=ignore_index, reduce=reduce, + reduction=reduction) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction) +def binary_cross_entropy(input, target, weight=None, size_average=None, + reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + binary_cross_entropy, tens_ops, input, target, weight=weight, + size_average=size_average, reduce=reduce, reduction=reduction) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if target.size() != input.size(): + raise ValueError("MSG" + "MSG".format(target.size(), input.size())) if weight is not None: + new_size = _infer_size(target.size(), weight.size()) + weight = weight.expand(new_size) return torch._C._nn.binary_cross_entropy( + input, target, weight, reduction_enum) +def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, + reduce=None, reduction='MSG', pos_weight=None): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + binary_cross_entropy_with_logits, tens_ops, input, target, weight=weight, + size_average=size_average, reduce=reduce, reduction=reduction, + pos_weight=pos_weight) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) if not (target.size() == input.size()): + raise ValueError("MSG".format(target.size(), input.size())) return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum) +def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='MSG', beta=1.0): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + smooth_l1_loss, tens_ops, input, target, size_average=size_average, + reduce=reduce, reduction=reduction, beta=beta) + if not (target.size() == input.size()): + warnings.warn("MSG" + "MSG" + "MSG".format(target.size(), input.size()), + stacklevel=2) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) expanded_input, expanded_target = torch.broadcast_tensors(input, target) + return torch._C._nn.smooth_l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction), beta) +def l1_loss(input, target, size_average=None, reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, + reduction=reduction) + if not (target.size() == input.size()): + warnings.warn("MSG" + "MSG" + "MSG".format(target.size(), input.size()), + stacklevel=2) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) +def mse_loss(input, target, size_average=None, reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + mse_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, + reduction=reduction) + if not (target.size() == input.size()): + warnings.warn("MSG" + "MSG" + "MSG".format(target.size(), input.size()), + stacklevel=2) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) expanded_input, expanded_target = torch.broadcast_tensors(input, target) + return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) +def margin_ranking_loss(input1, input2, target, margin=0, size_average=None, + reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input1, input2, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + margin_ranking_loss, tens_ops, input1, input2, target, margin=margin, + size_average=size_average, reduce=reduce, reduction=reduction) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if input1.dim() == 0 or input2.dim() == 0 or target.dim() == 0: + raise RuntimeError(("MSG" + "MSG".format(input1.size(), input2.size(), target.size()))) + return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum) +def hinge_embedding_loss(input, target, margin=1.0, size_average=None, + reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + hinge_embedding_loss, tens_ops, input, target, margin=margin, + size_average=size_average, reduce=reduce, reduction=reduction) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch.hinge_embedding_loss(input, target, margin, reduction_enum) +def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + multilabel_margin_loss, tens_ops, input, target, size_average=size_average, + reduce=reduce, reduction=reduction) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) +def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + soft_margin_loss, tens_ops, input, target, size_average=size_average, + reduce=reduce, reduction=reduction) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch._C._nn.soft_margin_loss(input, target, reduction_enum) +def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, + reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + multilabel_soft_margin_loss, tens_ops, input, target, weight=weight, + size_average=size_average, reduce=reduce, reduction=reduction) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input)) if weight is not None: + loss = loss * weight loss = loss.sum(dim=1) / input.size(1) if reduction == 'MSG': + ret = loss + elif reduction == 'MSG': + ret = loss.mean() + elif reduction == 'MSG': + ret = loss.sum() + else: + ret = input + raise ValueError(reduction + "MSG") + return ret +def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, + reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input1, input2, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + cosine_embedding_loss, tens_ops, input1, input2, target, margin=margin, + size_average=size_average, reduce=reduce, reduction=reduction) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) +def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None, + reduce=None, reduction='MSG'): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, target) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + multi_margin_loss, tens_ops, input, target, p=p, margin=margin, + weight=weight, size_average=size_average, reduce=reduce, + reduction=reduction) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if p != 1 and p != 2: + raise ValueError('MSG') + if weight is not None: + if weight.dim() != 1: + raise ValueError('MSG') return torch._C._nn.multi_margin_loss(input, target, p, margin, weight, reduction_enum) +pixel_shuffle = _add_docstr(torch.pixel_shuffle, r) channel_shuffle = _add_docstr(torch.channel_shuffle, r) @_overload +def upsample(input, size=None, scale_factor=None, mode='MSG', align_corners=None): + + pass @_overload +def upsample(input, size=None, scale_factor=None, mode='MSG', align_corners=None): + + pass +def upsample(input, size=None, scale_factor=None, mode='MSG', align_corners=None): + r + warnings.warn("MSG") + return interpolate(input, size, scale_factor, mode, align_corners) @_overload +def interpolate(input, size=None, scale_factor=None, mode='MSG', align_corners=None, recompute_scale_factor=None): + + pass @_overload +def interpolate(input, size=None, scale_factor=None, mode='MSG', align_corners=None, recompute_scale_factor=None): + + pass @_overload +def interpolate(input, size=None, scale_factor=None, mode='MSG', align_corners=None, recompute_scale_factor=None): + + pass @_overload +def interpolate(input, size=None, scale_factor=None, mode='MSG', align_corners=None, recompute_scale_factor=None): + + pass def interpolate(input, size=None, scale_factor=None, mode='MSG', align_corners=None, recompute_scale_factor=None): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + interpolate, (input,), input, size=size, scale_factor=scale_factor, + mode=mode, align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor) if mode in ('MSG', 'MSG'): + if align_corners is not None: + raise ValueError("MSG" + "MSG") + else: + if align_corners is None: + warnings.warn("MSG" + "MSG" + "MSG" + "MSG".format(mode)) + align_corners = False dim = input.dim() - 2 + + + + if size is not None and scale_factor is not None: + raise ValueError('MSG') + elif size is not None: + assert scale_factor is None + scale_factors = None + if isinstance(size, (list, tuple)): + if len(size) != dim: + raise ValueError('MSG' + 'MSG'.format(dim, len(size))) + output_size = size + else: + output_size = [size for _ in range(dim)] + elif scale_factor is not None: + assert size is None + output_size = None + if isinstance(scale_factor, (list, tuple)): + if len(scale_factor) != dim: + raise ValueError('MSG' + 'MSG'.format(dim, len(scale_factor))) + scale_factors = scale_factor + else: + scale_factors = [scale_factor for _ in range(dim)] + else: + raise ValueError('MSG') if recompute_scale_factor is None: + if scale_factors is not None: + for scale in scale_factors: + if math.floor(scale) != scale: + warnings.warn("MSG" + "MSG" + "MSG" + "MSG" + "MSG") + break + elif recompute_scale_factor and size is not None: + raise ValueError("MSG") + + if mode == "MSG" and output_size is None: + recompute_scale_factor = True if recompute_scale_factor is not None and recompute_scale_factor: + if not torch.jit.is_scripting() and torch._C._get_tracing_state(): + output_size = [(torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], + dtype=torch.float32)).float())) for i in range(dim)] + else: + assert scale_factors is not None + output_size = [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)] + scale_factors = None if input.dim() == 3 and mode == 'MSG': + return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == 'MSG': + return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == 'MSG': + return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) if input.dim() == 3 and mode == 'MSG': + assert output_size is not None + return adaptive_avg_pool1d(input, output_size) + if input.dim() == 4 and mode == 'MSG': + assert output_size is not None + return adaptive_avg_pool2d(input, output_size) + if input.dim() == 5 and mode == 'MSG': + assert output_size is not None + return adaptive_avg_pool3d(input, output_size) if input.dim() == 3 and mode == 'MSG': + assert align_corners is not None + return torch._C._nn.upsample_linear1d(input, output_size, align_corners, scale_factors) + if input.dim() == 4 and mode == 'MSG': + assert align_corners is not None + return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors) + if input.dim() == 5 and mode == 'MSG': + assert align_corners is not None + return torch._C._nn.upsample_trilinear3d(input, output_size, align_corners, scale_factors) + if input.dim() == 4 and mode == 'MSG': + assert align_corners is not None + return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors) if input.dim() == 3 and mode == 'MSG': + raise NotImplementedError("MSG") + if input.dim() == 3 and mode == 'MSG': + raise NotImplementedError("MSG") + if input.dim() == 4 and mode == 'MSG': + raise NotImplementedError("MSG") + if input.dim() == 4 and mode == 'MSG': + raise NotImplementedError("MSG") + if input.dim() == 5 and mode == 'MSG': + raise NotImplementedError("MSG") + if input.dim() == 5 and mode == 'MSG': + raise NotImplementedError("MSG") raise NotImplementedError("MSG" + "MSG" + "MSG".format(input.dim(), mode)) @_overload +def upsample_nearest(input, size=None, scale_factor=None): + + pass @_overload +def upsample_nearest(input, size=None, scale_factor=None): + + pass def upsample_nearest(input, size=None, scale_factor=None): + r + + warnings.warn("MSG") + return interpolate(input, size, scale_factor, mode='MSG') @_overload +def upsample_bilinear(input, size=None, scale_factor=None): + + pass @_overload +def upsample_bilinear(input, size=None, scale_factor=None): + + pass @_overload +def upsample_bilinear(input, size=None, scale_factor=None): + + pass @_overload +def upsample_bilinear(input, size=None, scale_factor=None): + + pass def upsample_bilinear(input, size=None, scale_factor=None): + r + + warnings.warn("MSG") + return interpolate(input, size, scale_factor, mode='MSG', align_corners=True) +GRID_SAMPLE_INTERPOLATION_MODES = { + 'MSG': 0, + 'MSG': 1, +} GRID_SAMPLE_PADDING_MODES = { + 'MSG': 0, + 'MSG': 1, + 'MSG': 2, +} +def grid_sample(input, grid, mode='MSG', padding_mode='MSG', align_corners=None): + + r + if not torch.jit.is_scripting(): + tens_ops = (input, grid) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + grid_sample, tens_ops, input, grid, mode=mode, padding_mode=padding_mode, + align_corners=align_corners) + if mode != 'MSG' and mode != 'MSG': + raise ValueError("MSG" + "MSG".format(mode)) + if padding_mode != 'MSG' and padding_mode != 'MSG' and padding_mode != 'MSG': + raise ValueError("MSG" + "MSG" + "MSG".format(padding_mode)) if mode == 'MSG': + mode_enum = 0 + else: + mode_enum = 1 if padding_mode == 'MSG': + padding_mode_enum = 0 + elif padding_mode == 'MSG': + padding_mode_enum = 1 + else: + padding_mode_enum = 2 if align_corners is None: + warnings.warn("MSG" + "MSG" + "MSG" + "MSG") + align_corners = False return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners) +def affine_grid(theta, size, align_corners=None): + + r + if not torch.jit.is_scripting(): + if type(theta) is not Tensor and has_torch_function((theta,)): + return handle_torch_function( + affine_grid, (theta,), theta, size, align_corners=align_corners) + if align_corners is None: + warnings.warn("MSG" + "MSG" + "MSG" + "MSG") + align_corners = False + if not theta.is_floating_point(): + raise ValueError("MSG" + .format(theta.dtype)) + + if len(size) == 4: + if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3: + raise ValueError("MSG" + "MSG".format(size, theta.shape)) + spatial_size = size[-2:] + elif len(size) == 5: + if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4: + raise ValueError("MSG" + "MSG".format(size, theta.shape)) + spatial_size = size[-3:] + else: + raise NotImplementedError("MSG" + "MSG" + "MSG".format(size)) + + if align_corners and min(spatial_size) == 1: + warnings.warn("MSG" + "MSG" + "MSG" + "MSG") + elif min(size) <= 0: + raise ValueError("MSG" + .format(size)) return torch.affine_grid_generator(theta, size, align_corners) +def _pad(input, pad, mode='MSG', value=0): + + r + if not torch.jit.is_scripting(): + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + _pad, (input,), input, pad, mode=mode, value=value) + assert len(pad) % 2 == 0, 'MSG' + assert len(pad) // 2 <= input.dim(), 'MSG' + if mode == 'MSG': + return _VF.constant_pad_nd(input, pad, value) + else: + assert value == 0, 'MSG's dimension {} is not supported"MSG"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."MSG"bias cannot be added to static key."MSG"bias cannot be added to static value."MSG"_grad_input_padding 'MSG' argument not provided. Default of 1 is used."MSG"input_size must have {} elements (got {})"MSG"requested an input grad size of {}, but valid sizes range "MSG"from {} to {} (for a grad_output of {})"MSG"grad.conv1d_input requires specifying an input_size"MSG"grad.conv2d_input requires specifying an input_size"MSG"grad.conv3d_input requires specifying an input_size"MSG"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "MSG"The distribution of values may be incorrect."MSG"negative_slope {} not a valid number"MSG"Unsupported nonlinearity {}"MSG"Only tensors with 2 dimensions are supported"MSG"Only tensors with 3, 4, or 5 dimensions are supported"MSG"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"MSG"Mode {} not supported, please use one of {}"MSG"Only tensors with 2 or more dimensions are supported"MSG"Only tensors with 2 dimensions are supported"MSG"nn.init.{} is now deprecated in favor of nn.init.{}."MSG"reduction='MSG' is deprecated, please use reduction='MSG' instead."MSG"{} is not a valid value for reduction"MSG"size_average and reduce args will be deprecated, please use reduction='MSG' instead."MSG"Input shape must be `(N, C, H, W)`!"MSG"Input shape must be `(N, C, D, H, W)`!"MSG"Input shape must be `(N, C, L)`!"MSG"Input shape must be `(N, C, H, W)`!"MSG"Input shape must be `(N, C, D, H, W)`!"MSG"keyword argument min_value is deprecated and rename to min_val"MSG"keyword argument max_value is deprecated and rename to max_val"MSG"embed_dim must be divisible by num_heads"MSG"cutoffs should be a sequence of unique, positive "MSG"integers sorted in an increasing order, where "MSG"each value is between 1 and n_classes-1"MSG"Target values should be in [0, {}], "MSG"but values in range [{}, {}] "MSG"were found. "MSG"qconfig"MSG"nn.Container is deprecated. All of it'MSG'{}'MSG'log_input'MSG'full'MSG'eps'MSG'reduction'MSG'mean'MSG'reduction'MSG'mean'MSG'reduction'MSG'mean'MSG'reduction'MSG'mean'MSG'mean'MSG'weight'MSG'pos_weight'MSG'margin'MSG'reduction'MSG'mean'MSG'reduction'MSG'mean'MSG'reduction'MSG'mean'MSG'reduction'MSG'mean'MSG'ignore_index'MSG'reduction'MSG'mean'MSG'reduction'MSG'mean'MSG'margin'MSG'reduction'MSG'mean'MSG'margin'MSG'reduction'MSG'mean'MSG'p'MSG'margin'MSG'reduction'MSG'mean'MSG'margin'MSG'p'MSG'eps'MSG'swap'MSG'reduction'MSG'mean'MSG'margin'MSG'swap'MSG'reduction'MSG'mean'MSG'blank'MSG'reduction'MSG'mean'MSG'T'MSG'Module'MSG'IncompatibleKeys'MSG'missing_keys'MSG'unexpected_keys'MSG''MSG'\n'MSG' 'MSG'\n'MSG'\n'MSG'Module'MSG'_buffers'MSG'.'MSG''MSG'_parameters'MSG'.'MSG''MSG'Module'MSG'.'MSG''MSG'Module'MSG'nn.Module.to only accepts floating point 'MSG'dtypes, but got desired dtype={}'MSG'Module'MSG'_forward_pre_hooks'MSG'_state_dict_hooks'MSG'_load_state_dict_pre_hooks'MSG'_non_persistent_buffers_set'MSG'Module'MSG'_parameters'MSG'_parameters'MSG'_buffers'MSG'_buffers'MSG'_modules'MSG'_modules'MSG'Module'MSG'_parameters'MSG'_modules'MSG'_buffers'MSG'T_destination'MSG''MSG'.'MSG'size mismatch for {}: copying a param with shape {} from checkpoint, 'MSG'the shape in current model is {}.'MSG'While copying the parameter named "MSG", 'MSG'whose dimensions in the model are {} and 'MSG'whose dimensions in the checkpoint are {}, 'MSG'an exception occurred : {}.'MSG'.'MSG'_metadata'MSG''MSG'.'MSG'Unexpected key(s) in state_dict: {}. 'MSG', 'MSG'"MSG"'MSG'Missing key(s) in state_dict: {}. 'MSG', 'MSG'"MSG"'MSG'Error(s) in loading state_dict for {}:\n\t{}'MSG''MSG'.'MSG''MSG''MSG''MSG'Module'MSG'Module'MSG'Module'MSG'Module'MSG''MSG'.'MSG''MSG'_is_replica'MSG''MSG'\n'MSG'('MSG'): 'MSG'('MSG'\n 'MSG'\n 'MSG'\n'MSG')'MSG'size'MSG'alpha'MSG'beta'MSG'k'MSG'{size}, alpha={alpha}, beta={beta}, k={k}'MSG'{size}, alpha={alpha}, beta={beta}, k={k}'MSG'normalized_shape'MSG'eps'MSG'elementwise_affine'MSG'weight'MSG'bias'MSG'{normalized_shape}, eps={eps}, 'MSG'elementwise_affine={elementwise_affine}'MSG'num_groups'MSG'num_channels'MSG'eps'MSG'affine'MSG'weight'MSG'bias'MSG'{num_groups}, {num_channels}, eps={eps}, 'MSG'affine={affine}'MSG'padding'MSG'value'MSG'constant'MSG'padding={}, value={}'MSG'padding'MSG'value'MSG'padding'MSG'reflect'MSG'{}'MSG'padding'MSG'replicate'MSG'{}'MSG'upscale_factor'MSG'upscale_factor={}'MSG'kernel_size'MSG'stride'MSG'padding'MSG'dilation'MSG'return_indices'MSG'ceil_mode'MSG'kernel_size={kernel_size}, stride={stride}, padding={padding}'MSG', dilation={dilation}, ceil_mode={ceil_mode}'MSG'kernel_size={}, stride={}, padding={}'MSG'kernel_size'MSG'stride'MSG'padding'MSG'ceil_mode'MSG'count_include_pad'MSG'kernel_size={}, stride={}, padding={}'MSG'kernel_size'MSG'stride'MSG'padding'MSG'ceil_mode'MSG'count_include_pad'MSG'divisor_override'MSG'kernel_size'MSG'stride'MSG'padding'MSG'ceil_mode'MSG'count_include_pad'MSG'divisor_override'MSG'padding'MSG'ceil_mode'MSG'count_include_pad'MSG'kernel_size'MSG'return_indices'MSG'output_size'MSG'output_ratio'MSG'_random_samples'MSG'kernel_size'MSG'return_indices'MSG'output_size'MSG'output_ratio'MSG'_random_samples'MSG'norm_type'MSG'kernel_size'MSG'stride'MSG'ceil_mode'MSG'norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, 'MSG'ceil_mode={ceil_mode}'MSG'output_size'MSG'return_indices'MSG'output_size={}'MSG'output_size'MSG'output_size={}'MSG'RNN_TANH'MSG'RNN_RELU'MSG'mode'MSG'input_size'MSG'hidden_size'MSG'num_layers'MSG'bias'MSG'batch_first'MSG'dropout'MSG'bidirectional'MSG'all_weights'MSG'LSTM'MSG'GRU'MSG'RNN_TANH'MSG'RNN_RELU'MSG'_reverse'MSG''MSG'weight_ih_l{}{}'MSG'weight_hh_l{}{}'MSG'bias_ih_l{}{}'MSG'bias_hh_l{}{}'MSG'input must have {} dimensions, got {}'MSG'input.size(-1) must be equal to input_size. Expected {}, got {}'MSG'Expected hidden size {}, got {}'MSG'{input_size}, {hidden_size}'MSG', num_layers={num_layers}'MSG', bias={bias}'MSG', batch_first={batch_first}'MSG', dropout={dropout}'MSG', bidirectional={bidirectional}'MSG'all_weights'MSG'all_weights'MSG'_reverse'MSG''MSG'weight_ih_l{}{}'MSG'weight_hh_l{}{}'MSG'bias_ih_l{}{}'MSG'bias_hh_l{}{}'MSG'nonlinearity'MSG'tanh'MSG'tanh'MSG'RNN_TANH'MSG'relu'MSG'RNN_RELU'MSG'LSTM'MSG'Expected hidden[0] size {}, got {}'MSG'Expected hidden[1] size {}, got {}'MSG'GRU'MSG'input_size'MSG'hidden_size'MSG'bias'MSG'bias_ih'MSG'bias_hh'MSG'{input_size}, {hidden_size}'MSG'bias'MSG', bias={bias}'MSG'nonlinearity'MSG', nonlinearity={nonlinearity}'MSG''MSG'input_size'MSG'hidden_size'MSG'bias'MSG'nonlinearity'MSG''MSG'[0]'MSG'[1]'MSG''MSG'num_embeddings'MSG'embedding_dim'MSG'padding_idx'MSG'max_norm'MSG'norm_type'MSG'scale_grad_by_freq'MSG'sparse'MSG'Padding_idx must be within num_embeddings'MSG'Padding_idx must be within num_embeddings'MSG'Shape of weight does not match num_embeddings and embedding_dim'MSG'{num_embeddings}, {embedding_dim}'MSG', padding_idx={padding_idx}'MSG', max_norm={max_norm}'MSG', norm_type={norm_type}'MSG', scale_grad_by_freq={scale_grad_by_freq}'MSG', sparse=True'MSG'Embeddings parameter is expected to be 2-dimensional'MSG'num_embeddings'MSG'embedding_dim'MSG'max_norm'MSG'norm_type'MSG'scale_grad_by_freq'MSG'mode'MSG'sparse'MSG'include_last_offset'MSG'mean'MSG'Shape of weight does not match num_embeddings and embedding_dim'MSG'{num_embeddings}, {embedding_dim}'MSG', max_norm={max_norm}'MSG', norm_type={norm_type}'MSG', scale_grad_by_freq={scale_grad_by_freq}'MSG', mode={mode}'MSG'mean'MSG'EmbeddingBag'MSG'Embeddings parameter is expected to be 2-dimensional'MSG'-inf'MSG'norm'MSG'norm'MSG'activation'MSG'activation'MSG'activation'MSG'activation'MSG'size'MSG'scale_factor'MSG'mode'MSG'align_corners'MSG'name'MSG'nearest'MSG'scale_factor='MSG'size='MSG', mode='MSG'nearest'MSG'bilinear'MSG'Input dimension should be at least {}'MSG'Expected more than 1 value per channel when training, got input size {}'MSG'Module'MSG'Identity'MSG'Linear'MSG'Conv1d'MSG'Conv2d'MSG'Conv3d'MSG'ConvTranspose1d'MSG'ConvTranspose2d'MSG'ConvTranspose3d'MSG'Threshold'MSG'ReLU'MSG'Hardtanh'MSG'ReLU6'MSG'Sigmoid'MSG'Tanh'MSG'Softmax'MSG'Softmax2d'MSG'LogSoftmax'MSG'ELU'MSG'SELU'MSG'CELU'MSG'GLU'MSG'GELU'MSG'Hardshrink'MSG'LeakyReLU'MSG'LogSigmoid'MSG'Softplus'MSG'Softshrink'MSG'MultiheadAttention'MSG'PReLU'MSG'Softsign'MSG'Softmin'MSG'Tanhshrink'MSG'RReLU'MSG'L1Loss'MSG'NLLLoss'MSG'KLDivLoss'MSG'MSELoss'MSG'BCELoss'MSG'BCEWithLogitsLoss'MSG'NLLLoss2d'MSG'PoissonNLLLoss'MSG'CosineEmbeddingLoss'MSG'CTCLoss'MSG'HingeEmbeddingLoss'MSG'MarginRankingLoss'MSG'MultiLabelMarginLoss'MSG'MultiLabelSoftMarginLoss'MSG'MultiMarginLoss'MSG'SmoothL1Loss'MSG'SoftMarginLoss'MSG'CrossEntropyLoss'MSG'Container'MSG'Sequential'MSG'ModuleList'MSG'ModuleDict'MSG'ParameterList'MSG'ParameterDict'MSG'AvgPool1d'MSG'AvgPool2d'MSG'AvgPool3d'MSG'MaxPool1d'MSG'MaxPool2d'MSG'MaxPool3d'MSG'MaxUnpool1d'MSG'MaxUnpool2d'MSG'MaxUnpool3d'MSG'FractionalMaxPool2d'MSG'LPPool1d'MSG'LPPool2d'MSG'LocalResponseNorm'MSG'BatchNorm1d'MSG'BatchNorm2d'MSG'BatchNorm3d'MSG'InstanceNorm1d'MSG'InstanceNorm2d'MSG'InstanceNorm3d'MSG'LayerNorm'MSG'GroupNorm'MSG'SyncBatchNorm'MSG'Dropout'MSG'Dropout2d'MSG'Dropout3d'MSG'AlphaDropout'MSG'FeatureAlphaDropout'MSG'ReflectionPad1d'MSG'ReflectionPad2d'MSG'ReplicationPad2d'MSG'ReplicationPad1d'MSG'ReplicationPad3d'MSG'CrossMapLRN2d'MSG'Embedding'MSG'EmbeddingBag'MSG'RNNBase'MSG'RNN'MSG'LSTM'MSG'GRU'MSG'RNNCellBase'MSG'RNNCell'MSG'LSTMCell'MSG'GRUCell'MSG'PixelShuffle'MSG'Upsample'MSG'UpsamplingNearest2d'MSG'UpsamplingBilinear2d'MSG'PairwiseDistance'MSG'AdaptiveMaxPool1d'MSG'AdaptiveMaxPool2d'MSG'AdaptiveMaxPool3d'MSG'AdaptiveAvgPool1d'MSG'AdaptiveAvgPool2d'MSG'AdaptiveAvgPool3d'MSG'TripletMarginLoss'MSG'ZeroPad2d'MSG'ConstantPad1d'MSG'ConstantPad2d'MSG'ConstantPad3d'MSG'Bilinear'MSG'CosineSimilarity'MSG'Unfold'MSG'Fold'MSG'AdaptiveLogSoftmaxWithLoss'MSG'TransformerEncoder'MSG'TransformerDecoder'MSG'TransformerEncoderLayer'MSG'TransformerDecoderLayer'MSG'Transformer'MSG'Flatten'MSG'Unflatten'MSG'Hardsigmoid'MSG'Hardswish'MSG'SiLU'MSG'TripletMarginWithDistanceLoss'MSG'x'MSG'x'MSG'Using -1 to represent CPU tensor is deprecated. Please use a 'MSG'device object or string instead, e.g., "MSG".'MSG'process_group'MSG'reducer'MSG'require_forward_param_sync'MSG'require_backward_grad_sync'MSG'cpu'MSG'All dicts must have the same number of keys'MSG'cpu'MSG'Broadcast function not implemented for CPU tensors'MSG'cpu'MSG'Gather function not implemented for CPU tensors'MSG'Was asked to gather along dimension 0, but all 'MSG'input tensors were scalars; will instead unsqueeze 'MSG'and return a vector.'MSG'replicate'MSG'scatter'MSG'parallel_apply'MSG'gather'MSG'data_parallel'MSG'DataParallel'MSG'DistributedDataParallel'MSG'zeros'MSG'qconfig must be provided for QAT module'MSG'qat.'MSG'.from_float only works for 'MSG'qconfig'MSG'Input float module must have qconfig defined'MSG'Input float module must have a valid qconfig'MSG'qconfig must be provided for QAT module'MSG' qat.'MSG'.from_float only works for 'MSG'qconfig'MSG'Input float module must have qconfig defined'MSG'Input float module must have a valid qconfig'MSG'Linear'MSG'Conv2d'MSG'zeros'MSG'zeros'MSG'zeros'MSG'zeros'MSG'zeros'MSG'zeros'MSG'nearest'MSG'nearest'MSG'bilinear'MSG'nearest'MSG'Unsupported dtype on dynamic quantized linear!'MSG'DynamicQuantizedLinear'MSG'in_features={}, out_features={}, dtype={}'MSG', qscheme={}'MSG'version'MSG'nn.quantized.dynamic.Linear.from_float only works for nn.Linear'MSG'qconfig'MSG'Input float module must have qconfig defined'MSG'The only supported dtypes for dynamic quantized linear are qint8 and float16'MSG'Unsupported dtype specified for dynamic quantized Linear!'MSG'param'MSG'param'MSG'LSTM'MSG'DynamicQuantizedRNN'MSG'{input_size}, {hidden_size}'MSG', num_layers={num_layers}'MSG', bias={bias}'MSG', batch_first={batch_first}'MSG', dropout={dropout}'MSG', bidirectional={bidirectional}'MSG'\n'MSG'('MSG'): 'MSG'('MSG'\n 'MSG'\n 'MSG'\n'MSG')'MSG'input must have {} dimensions, got {}'MSG'input.size(-1) must be equal to input_size. Expected {}, got {}'MSG'Expected hidden size {}, got {}'MSG'Expected hidden size {}, got {}'MSG'version'MSG'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'MSG'qconfig'MSG'Input float module must have qconfig defined'MSG'Unsupported dtype for dynamic RNN quantization: {}'MSG'LSTM'MSG'Only LSTM is supported for QuantizedRNN for now'MSG'_reverse'MSG''MSG'weight_{}_l{}{}'MSG'bias_{}_l{}{}'MSG'ih'MSG'hh'MSG'Unsupported dtype specified for dynamic quantized LSTM!'MSG'weight'MSG'bias'MSG'_reverse'MSG''MSG'weight_ih_l{layer_idx}{suffix}'MSG'weight_hh_l{layer_idx}{suffix}'MSG'weight'MSG'weight'MSG'bias_ih_l{layer_idx}{suffix}'MSG'bias_hh_l{layer_idx}{suffix}'MSG'bias'MSG'bias'MSG'weight'MSG'bias'MSG'forward'MSG'forward_packed'MSG'forward_tensor'MSG'LSTM'MSG'DynamicQuantizedLSTM'MSG'Expected hidden[0] size {}, got {}'MSG'Expected hidden[1] size {}, got {}'MSG'input_size'MSG'hidden_size'MSG'bias'MSG'bias_ih'MSG'bias_hh'MSG'DynamicQuantizedRNNBase'MSG'{input_size}, {hidden_size}'MSG'bias'MSG', bias={bias}'MSG'nonlinearity'MSG', nonlinearity={nonlinearity}'MSG''MSG'nn.quantized.dynamic.RNNCellBase.from_float \ + only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell'MSG'qconfig'MSG'Input float module must have qconfig defined'MSG'Unsupported dtype for dynamic RNN quantization: {}'MSG'Only LSTMCell, GRUCell and RNNCell \ + are supported for QuantizedRNN for now'MSG'weight'MSG'bias'MSG'weight'MSG'weight_ih'MSG'weight'MSG'weight_hh'MSG'bias'MSG'bias_ih'MSG'bias'MSG'bias_hh'MSG'weight'MSG'bias'MSG'_packed_weight_ih'MSG'_packed_weight_hh'MSG'_packed_weight_ih'MSG'_packed_weight_hh'MSG'input_size'MSG'hidden_size'MSG'bias'MSG'nonlinearity'MSG'DynamicQuantizedRNNCell'MSG''MSG'DynamicQuantizedLSTMCell'MSG'[0]'MSG'[1]'MSG'DynamicQuantizedGRUCell'MSG''MSG'Linear'MSG'LSTM'MSG'LSTMCell'MSG'RNNCell'MSG'GRUCell'MSG'QuantizedReLU'MSG'QuantizedReLU6'MSG'QuantizedHardswish'MSG'QuantizedELU'MSG'QuantizedBatchNorm2d'MSG'QuantizedBatchNorm3d'MSG'zeros'MSG'zeros'MSG'in_channels must be divisible by groups'MSG'out_channels must be divisible by groups'MSG'{in_channels}, {out_channels}, kernel_size={kernel_size}'MSG', stride={stride}, scale={scale}, zero_point={zero_point}'MSG', padding={padding}'MSG', dilation={dilation}'MSG', output_padding={output_padding}'MSG', groups={groups}'MSG', bias=False'MSG'weight'MSG'bias'MSG'scale'MSG'zero_point'MSG'weight'MSG'bias'MSG'weight'MSG'bias'MSG'scale'MSG'scale'MSG'zero_point'MSG'zero_point'MSG'Weight observer must have a dtype of qint8'MSG'zeros'MSG'QuantizedConv1d'MSG' nnq.'MSG'.from_float only works for 'MSG'qconfig'MSG'Input float module must have qconfig defined.'MSG'zeros'MSG'QuantizedConv2d'MSG'weight_fake_quant'MSG'activation_post_process'MSG'Input QAT module must have observer attached'MSG' nnq.'MSG'.from_float only works for 'MSG'qconfig'MSG'Input float module must have qconfig defined.'MSG'zeros'MSG'QuantizedConv3d'MSG' nnq.'MSG'.from_float only works for 'MSG'qconfig'MSG'Input float module must have qconfig defined.'MSG'zeros'MSG'Only "MSG" padding mode is supported for {}'MSG' nnq.'MSG'.from_float only works for 'MSG'qconfig'MSG'Input float module must have qconfig defined.'MSG'Weight observer must have a dtype of qint8'MSG'zeros'MSG'QuantizedConvTranpose1d'MSG'zeros'MSG'QuantizedConvTranpose2d'MSG'Unsupported dtype on quantized embedding!'MSG'Unsupported dtype on quantized embedding!'MSG'Unsupported dtype on quantized embedding!'MSG'dtype'MSG'_packed_weight'MSG'version'MSG'dtype'MSG'dtype'MSG'_packed_weight'MSG'_packed_weight'MSG'Shape of weight does not match num_embeddings and embedding_dim'MSG'QuantizedEmbedding'MSG'num_embeddings={}, embedding_dim={}, dtype={}, qscheme={}'MSG'nnq.'MSG'.from_float only works for 'MSG'qconfig'MSG'Embedding input float module must have qconfig defined'MSG'The only supported dtype for nnq.Embedding is torch.quint8'MSG'sum'MSG'QuantizedEmbeddingBag'MSG'nnq.'MSG'.from_float only works for 'MSG'qconfig'MSG'EmbeddingBag input float module must have qconfig defined'MSG'The only supported dtype for nnq.EmbeddingBag is torch.quint8'MSG'scale'MSG'zero_point'MSG'scale'MSG'zero_point'MSG'QFunctional'MSG'scale={}, zero_point={}'MSG'Unsupported dtype on dynamic quantized linear!'MSG'Unsupported dtype on dynamic quantized linear!'MSG'dtype'MSG'_packed_params'MSG'version'MSG'dtype'MSG'dtype'MSG'weight'MSG'bias'MSG'weight'MSG'bias'MSG'_packed_params'MSG'_packed_params'MSG'Unsupported dtype specified for quantized Linear!'MSG'QuantizedLinear'MSG'in_features={}, out_features={}, scale={}, zero_point={}, qscheme={}'MSG'scale'MSG'zero_point'MSG'scale'MSG'scale'MSG'zero_point'MSG'zero_point'MSG'version'MSG'weight'MSG'bias'MSG'_packed_params.weight'MSG'_packed_params.bias'MSG'weight_fake_quant'MSG' nnq.'MSG'.from_float only works for 'MSG'qconfig'MSG'Input float module must have qconfig defined'MSG'Weight observer must have dtype torch.qint8'MSG'QuantizedLayerNorm'MSG'num_groups'MSG'num_channels'MSG'eps'MSG'affine'MSG'QuantizedGroupNorm'MSG'QuantizedInstanceNorm1d'MSG'QuantizedInstanceNorm2d'MSG'QuantizedInstanceNorm3d'MSG'\n'MSG'('MSG'): 'MSG'('MSG'\n 'MSG'\n 'MSG'\n'MSG')'MSG'scale'MSG'zero_point'MSG'activation_post_process'MSG'scale={}, zero_point={}, dtype={}'MSG'BatchNorm2d'MSG'BatchNorm3d'MSG'Conv1d'MSG'Conv2d'MSG'Conv3d'MSG'ConvTranspose1d'MSG'ConvTranspose2d'MSG'DeQuantize'MSG'Linear'MSG'MaxPool2d'MSG'Quantize'MSG'ReLU'MSG'ReLU6'MSG'Hardswish'MSG'ELU'MSG'LayerNorm'MSG'GroupNorm'MSG'InstanceNorm1d'MSG'InstanceNorm2d'MSG'InstanceNorm3d'MSG'Embedding'MSG'EmbeddingBag'MSG'FloatFunctional'MSG'QFunctional'MSG'expected torch.Tensor, but got: {}'MSG'Found two parameters on different devices, 'MSG'this is currently not supported.'MSG'BasePruningMethod'MSG'"MSG" need to have the attribute `dim` defined.'MSG'Index is out of bounds for tensor with dimensions {}'MSG'Only "MSG" PRUNING_TYPE supported for 'MSG'PackedSequence'MSG'data'MSG'batch_sizes'MSG'sorted_indices'MSG'unsorted_indices'MSG'data'MSG'batch_sizes'MSG'sorted_indices'MSG'unsorted_indices'MSG'cuda'MSG'cpu'MSG'cpu'MSG'device'MSG'dtype'MSG'cpu'MSG'spectral_norm'MSG'{}'MSG'{}'MSG't know to interpret Constant node"MSG"Failed to export an ONNX attribute 'MSG', since it'MSG't know to interpret ListConstruct node"MSG"Unexpected node type: {}"MSG"ONNX symbolic expected a constant value of the {} argument, got `{}`"MSG"prim::ListConstruct"MSG"prim::ListConstruct"MSG"ONNX export failed on "MSG" because "MSG" not supported"MSG"ONNX export failed on {}, which is not implemented for opset {}. "MSG"Try exporting with other opset versions."MSG"Sort"MSG"Out parameter is not supported"MSG"Shape"MSG"Gather"MSG"Constant"MSG"Sort"MSG"Ascending is not supported"MSG"TopK"MSG"TopK"MSG"TopK"MSG"Out parameter is not supported"MSG"Constant"MSG"Reshape"MSG"Constant"MSG"TopK"MSG"Ascending is not supported"MSG"TopK"MSG"TopK"MSG"onnx:Resize"MSG"onnx:Upsample"MSG"You are trying to export the model with "MSG" for ONNX opset version "MSG""MSG". "MSG"This operator might cause results to not match the expected results by PyTorch.\n"MSG"ONNX'MSG's Interpolation until opset 11. "MSG"Attributes to determine how to transform the input were added in onnx:Resize in opset 11 "MSG"to support Pytorch'MSG'"MSG"Constant"MSG"Concat"MSG"Unsqueeze"MSG"Unsqueeze"MSG"Constant"MSG"Slice"MSG"Gather"MSG"Slice"MSG"Unsqueeze"MSG"Mul"MSG"ReduceSum"MSG"ReduceMean"MSG"ReduceMax"MSG"Unsqueeze"MSG"Concat"MSG"ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]"MSG"DequantizeLinear"MSG"QuantizeLinear"MSG"Constant"MSG"Constant"MSG"Clip"MSG"Cast"MSG"Clip"MSG"Gather"MSG"ATen"MSG"Shape"MSG"Unsqueeze"MSG"Concat"MSG"Shape"MSG"Unsqueeze"MSG"Shape"MSG"Concat"MSG"Reshape"MSG"ConstantOfShape"MSG"Shape"MSG"ScatterND"MSG"ScatterND"MSG"pixel_shuffle"MSG"only support 4d input"MSG"DepthToSpace"MSG"CRD"MSG"asymmetric"MSG"nearest"MSG"align_corners"MSG"pytorch_half_pixel"MSG"Constant"MSG"Shape"MSG"Cast"MSG"Long"MSG"Concat"MSG"Constant"MSG"Resize"MSG"floor"MSG"Resize"MSG"floor"MSG"nearest"MSG"nearest"MSG"nearest"MSG"linear"MSG"linear"MSG"linear"MSG"cubic"MSG"asymmetric"MSG"nearest"MSG"align_corners"MSG"pytorch_half_pixel"MSG"Constant"MSG"Shape"MSG"Cannot verify if the output_size is a scalar "MSG"while exporting interpolate. Assuming that it is not a scalar."MSG"interpolate (with a scalar output_size)"MSG"missing input shape (try giving an array of output_size values)"MSG"Concat"MSG"Cast"MSG"Concat"MSG"Constant"MSG"Resize"MSG"floor"MSG"interpolate (with scales)"MSG"missing input shape"MSG"Resize"MSG"floor"MSG"gather"MSG"sparse_grad == True"MSG"ATen"MSG"gather"MSG"GatherElements"MSG"ATen"MSG"scatter"MSG"ScatterElements"MSG"Cast"MSG"ScatterElements"MSG"Constant"MSG"Cast"MSG"CumSum"MSG"onnx::SplitToSequence"MSG"SequenceLength"MSG"Size"MSG"SequenceAt"MSG"SequenceInsert"MSG"prim::ListConstruct"MSG"add"MSG"does not support adding dynamic tensor list to another"MSG"SequenceInsert"MSG"SequenceInsert"MSG"SequenceErase"MSG"ConcatFromSequence"MSG"ConcatFromSequence"MSG"Unique"MSG"Pad"MSG"Constant"MSG"AveragePool"MSG"Unique"MSG"Round"MSG"SplitToSequence"MSG"Unsqueeze"MSG"Constant"MSG"Constant"MSG"Add"MSG"Slice"MSG"SequenceAt"MSG"Constant"MSG"SplitToSequence"MSG"Constant"MSG"Constant"MSG"Sub"MSG"Mul"MSG"Constant"MSG"Constant"MSG"Cast"MSG"Concat"MSG"ConstantOfShape"MSG"Reshape"MSG"Constant"MSG"Transpose"MSG"Reshape"MSG"Constant"MSG"Cast"MSG"constant"MSG"Pad"MSG"reflect"MSG"Pad"MSG"edge"MSG"Pad"MSG"Det"MSG"Constant"MSG"Constant"MSG"Range"MSG"Range"MSG"Constant"MSG"Range"MSG"Unknown aten::arange signature taking "MSG" arguments."MSG"Gather"MSG"Constant"MSG"_caffe2::Range"MSG"Shape"MSG"Squeeze"MSG"Constant"MSG"Constant"MSG"Equal"MSG"If"MSG"Squeeze"MSG"Identity"MSG"Unsqueeze"MSG"Gemm"MSG"ATen"MSG"index_fill"MSG"ATen"MSG"index_copy"MSG"Cast"MSG"RIGHT"MSG"Cast"MSG"Cast"MSG"LEFT"MSG"Cast"MSG"Add"MSG"Constant"MSG"Sub"MSG"Constant"MSG"Range"MSG"Constant"MSG"Constant"MSG"Constant"MSG"Add"MSG"Constant"MSG"Pad"MSG"Constant"MSG"Constant"MSG"Mul"MSG"Constant"MSG"Concat"MSG"Unsqueeze"MSG"Unsqueeze"MSG"Constant"MSG"Constant"MSG"Constant"MSG"Gather"MSG"Gather"MSG"Transpose"MSG"Reshape"MSG"Add"MSG"dim"MSG"ONNX and PyTorch use different strategies to split the input. "MSG"Input rank must be known at export time."MSG"Flatten"MSG"Flatten"MSG"Constant"MSG"Constant"MSG"Unsqueeze"MSG"Constant"MSG"Concat"MSG"Constant"MSG"Loop"MSG"Gather"MSG"Gather"MSG"Unsqueeze"MSG"Unsqueeze"MSG"Slice"MSG"Gather"MSG"Slice"MSG"Unsqueeze"MSG"Mul"MSG"ReduceSum"MSG"ReduceMean"MSG"ReduceMax"MSG"Einsum"MSG"dropout"MSG"Constant"MSG"Constant"MSG"Dropout"MSG"NegativeLogLikelihoodLoss"MSG"NegativeLogLikelihoodLoss"MSG"Cast"MSG"Celu"MSG"Cast"MSG"Celu"MSG"Constant"MSG"Constant"MSG"Pow"MSG"scan"MSG"expand"MSG"expand_as"MSG"meshgrid"MSG"adaptive_max_pool1d"MSG"adaptive_max_pool2d"MSG"adaptive_max_pool3d"MSG"max_pool1d_with_indices"MSG"max_pool2d_with_indices"MSG"max_pool3d_with_indices"MSG"Multidirectional broadcasting is not supported in opset 7. "MSG"This might cause the onnx model to be incorrect, if inputs to max operators "MSG"have different shapes"MSG"Multidirectional broadcasting is not supported in opset 7. "MSG"This might cause the onnx model to be incorrect, if inputs to min operators "MSG"have different shapes"MSG"nonzero"MSG"where"MSG"scatter"MSG"scatter_add"MSG"erf"MSG"sign"MSG"isnan"MSG"gather"MSG"arange"MSG"masked_fill"MSG"index_fill"MSG"index_copy"MSG"align_corners == True"MSG"torch._C.Value (output_size) indexing"MSG"Upsample"MSG"nearest"MSG"nearest"MSG"nearest"MSG"linear"MSG"linear"MSG"linear"MSG"interpolate"MSG"align_corners == True"MSG"interpolate"MSG"dynamic scales in opset 8"MSG"interpolate"MSG"dynamic size in opset 8"MSG"Upsample"MSG"Only floating datatype is supported for these operators: "MSG"{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "MSG"the onnx model to be incorrect, if inputs have integer datatypes."MSG"Greater"MSG"Less"MSG"MatMul"MSG"MatMul"MSG"Unsqueeze"MSG"PRelu"MSG"PRelu"MSG"Gemm"MSG"Gemm"MSG"Gemm"MSG"Gemm"MSG"Flatten"MSG"Flatten"MSG"Flatten"MSG"Flatten"MSG"ConstantFill"MSG"Float"MSG"ConstantFill"MSG"Shape"MSG"Shape"MSG"Constant"MSG"Shape"MSG"Constant"MSG"Constant"MSG"Tile"MSG"prim::Constant"MSG"add"MSG"alpha != 1"MSG"Add"MSG"sub"MSG"alpha != 1"MSG"Sub"MSG"Mul"MSG"Cast"MSG"Cast"MSG"Cast"MSG"Cast"MSG"Div"MSG"Cast"MSG"Div"MSG"Cast"MSG"Div"MSG"Cast"MSG"Cast"MSG"Div"MSG"Div"MSG"Concat"MSG"Unsqueeze"MSG"Concat"MSG"Constant"MSG"Gemm"MSG"MatMul"MSG"MatMul"MSG"Gemm"MSG"Neg"MSG"Sqrt"MSG"Div"MSG"Tanh"MSG"Sin"MSG"Cos"MSG"Tan"MSG"Asin"MSG"Acos"MSG"Atan"MSG"Sigmoid"MSG"Sign"MSG"Slice"MSG"Unknown aten::{} signature"MSG"dtype"MSG"dtype"MSG"dtype"MSG"ATen"MSG"cumsum"MSG"ATen"MSG"_sample_dirichlet"MSG"ATen"MSG"_standard_gamma"MSG"Transpose"MSG"Constant"MSG"Constant"MSG"Constant"MSG"Equal"MSG"Expand"MSG"Shape"MSG"Expand"MSG"Gather"MSG"ATen"MSG"embedding_bag"MSG"Shape"MSG"Constant"MSG"Transpose"MSG"ATen"MSG"transpose"MSG"Transpose"MSG"Constant"MSG"Reshape"MSG"Shape"MSG"Reshape"MSG"Split"MSG"Split"MSG"Split"MSG"Split"MSG"Split"MSG"Squeeze"MSG"Squeeze"MSG"Gather"MSG"Squeeze"MSG"ONNX export squeeze with negative axis "MSG" might cause the onnx model to be incorrect. "MSG"Negative axis is not supported in ONNX. "MSG"Axis is converted to "MSG" based on input shape at export time. "MSG"Passing an tensor of different rank in execution will be incorrect."MSG"This model contains a squeeze operation on dimension "MSG" on an input "MSG"with unknown shape. Note that if the size of dimension "MSG" of the input "MSG"is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on "MSG"non-singleton dimensions, it is recommended to export this model using opset "MSG"version 11 or higher."MSG"Squeeze"MSG"This model contains a squeeze operation on dimension "MSG". The size of "MSG"this dimension in the given input is "MSG". The model will "MSG"be exported without the squeeze node. If the model is intended to be used with dynamic "MSG"input shapes, please use opset version 11 to "MSG"export the model."MSG"This model contains a squeeze operation on dimension "MSG". If the model is "MSG"intended to be used with dynamic input shapes, please use opset version 11 to export the model."MSG"Squeeze"MSG"Unsqueeze"MSG"PRelu"MSG"Relu"MSG"Ceil"MSG"Floor"MSG"Size"MSG"threshold"MSG"non-zero threshold"MSG"threshold"MSG"non-zero value"MSG"Relu"MSG"LeakyRelu"MSG"Transpose"MSG"Cast"MSG"Transpose"MSG"Cast"MSG"beta"MSG"has to be 1"MSG"input size not accessible"MSG"dilation"MSG"MaxPool"MSG"MaxPool"MSG"MaxPool"MSG"max_pool1d"MSG"max_pool2d"MSG"max_pool3d"MSG"max_pool1d_with_indices"MSG"max_pool2d_with_indices"MSG"max_pool3d_with_indices"MSG"input size not accessible"MSG"Pad"MSG"AveragePool"MSG"AveragePool"MSG"GlobalAveragePool"MSG"GlobalMaxPool"MSG"GlobalMaxPool"MSG"MaxPool"MSG"AveragePool"MSG"AveragePool"MSG"AveragePool"MSG"MaxPool"MSG"MaxPool"MSG"MaxPool"MSG"constant"MSG"Pad"MSG"reflect"MSG"Pad"MSG"edge"MSG"Pad"MSG"align_corners == True"MSG"Upsample"MSG"nearest"MSG"nearest"MSG"nearest"MSG"linear"MSG"linear"MSG"linear"MSG"Upsample"MSG"bitwise_not"MSG"non-bool tensor"MSG"Not"MSG"Cast"MSG"Not"MSG"Equal"MSG"Equal"MSG"Cast"MSG"Cast"MSG"Greater"MSG"Cast"MSG"Cast"MSG"Less"MSG"Cast"MSG"Cast"MSG"Cast"MSG"Cast"MSG"Cast"MSG"Constant"MSG"Where"MSG"dim"MSG"ONNX and PyTorch use different strategies to split the input. "MSG"Input rank must be known at export time."MSG"Transpose"MSG"LogSoftmax"MSG"Cast"MSG"Transpose"MSG"kernel_shape_i"MSG"strides_i"MSG"pads_i"MSG"dilations_i"MSG"group_i"MSG"output_padding_i"MSG"ConvTranspose"MSG"Conv"MSG"Add"MSG"batch_norm"MSG"Constant"MSG"Constant"MSG"BatchNormalization"MSG"batch_norm_dead_output-"MSG"batch_norm_dead_output-"MSG"ATen"MSG"layer_norm"MSG"Constant"MSG"Constant"MSG"ReduceMean"MSG"ReduceMean"MSG"Div"MSG"Constant"MSG"Constant"MSG"InstanceNormalization"MSG"ATen"MSG"unfold"MSG"Unsqueeze"MSG"Transpose"MSG"Concat"MSG"Unfold"MSG"input size not accessible"MSG"scale"MSG"does not support scale in Elu"MSG"input_scale"MSG"does not support input_scale in Elu"MSG"Elu"MSG"Selu"MSG"Constant"MSG"Reshape"MSG"Constant"MSG"Gather"MSG"ATen"MSG"ATen"MSG"index_fill"MSG"ATen"MSG"index_copy"MSG"Cast"MSG"ATen"MSG"type_as"MSG"ATen"MSG"cosine_similarity"MSG"Abs"MSG"Log"MSG"Cast"MSG"Cast"MSG"Pow"MSG"Cast"MSG"Clip"MSG"Clip"MSG"Clip"MSG"ReduceMax"MSG"Max"MSG"ReduceMax"MSG"ReduceMin"MSG"Min"MSG"ReduceMin"MSG"Exp"MSG"dropout"MSG"Dropout is a training op and should not be exported in inference mode. "MSG"Make sure to call eval() on the model, and to export it with param training=False."MSG"Dropout"MSG"training mode"MSG"feature_dropout"MSG"alpha_dropout"MSG"feature_alpha_dropout"MSG"ReduceL1"MSG"ReduceL2"MSG"ONNX export only p-norms with p of 1 or 2"MSG"ATen"MSG"conv_tbc"MSG"ATen"MSG"_unique"MSG"ATen"MSG"_unique2"MSG"Cast"MSG"Constant"MSG"Reshape"MSG"Cast"MSG"Concat"MSG"Cast"MSG"ConstantOfShape"MSG"Shape"MSG"ConstantOfShape"MSG"ConstantOfShape"MSG"Shape"MSG"ConstantOfShape"MSG"Constant"MSG"ConstantOfShape"MSG"Constant"MSG"Shape"MSG"ConstantOfShape"MSG"Concat"MSG"Unsqueeze"MSG"Unsqueeze"MSG"EyeLike"MSG"step!=1 is currently not supported"MSG"Unsqueeze"MSG"Unsqueeze"MSG"Unsqueeze"MSG"DynamicSlice"MSG"Unknown aten::slice signature"MSG"Clip"MSG"ONNX export unsqueeze with negative axis "MSG" might cause the onnx model to be incorrect. "MSG"Negative axis is not supported in ONNX. "MSG"Axis is converted to "MSG" based on input shape at export time. "MSG"Passing an tensor of different rank in execution will be incorrect."MSG"Unsqueeze"MSG"Sort"MSG"Out parameter is not supported for sort"MSG"Sort"MSG"input size not accessible"MSG"TopK"MSG"Shape"MSG"ReduceProd"MSG"TopK"MSG"Out parameter is not supported for topk"MSG"TopK"MSG"Ascending TopK is not supported"MSG"TopK"MSG"Cast"MSG"Cast"MSG"Cast"MSG"Cast"MSG"Cast"MSG"Unknown aten::to signature"MSG"Expand"MSG"Tile"MSG"pixel_shuffle"MSG"only support 4d input"MSG"Constant"MSG"Transpose"MSG"Constant"MSG"Exporting a model to ONNX with a batch_size other than 1, "MSG"with a variable length with "MSG" can cause an error "MSG"when running the ONNX model with a different batch size. "MSG"Make sure to save the model with a batch size of 1, "MSG"or define the initial states (h0/c0) as inputs of the model. "MSG"RNN/GRU/LSTM"MSG"dropout in training mode"MSG"Gather"MSG"Constant"MSG"_caffe2::Range"MSG"onnx memory_format support is not implemented"MSG"Lengths must be a Tensor for ONNX export"MSG"prim::PackPadded"MSG"prim::PadPacked"MSG"ConstantOfShape"MSG"is"MSG"randn"MSG"ConstantOfShape"MSG"is"MSG"rand"MSG"dim"MSG"ONNX and PyTorch use different strategies to split the input. "MSG"Input rank must be known at export time."MSG"Flatten"MSG"Flatten"MSG"Constant"MSG"Constant"MSG"Scatter"MSG"Cast"MSG"Scatter"MSG"scatter_add"MSG"input size not accessible"MSG"Constant"MSG"Constant"MSG"Constant"MSG"OneHot"MSG"gather"MSG"sparse_grad == True"MSG"Constant"MSG"Constant"MSG"Cast"MSG"OneHot"MSG"Mul"MSG"Unsqueeze"MSG"ReduceSum"MSG"Mul"MSG"ReduceMean"MSG"ReduceMean"MSG"ReduceMean"MSG"ReduceMean"MSG"Mul"MSG"Abs"MSG"Sub"MSG"Mul"MSG"Constant"MSG"Div"MSG"Constant"MSG"Sqrt"MSG"std"MSG"Unknown input rank. Cannot compute std along dimensions."MSG"ATen"MSG"arange"MSG"Unsqueeze"MSG"Squeeze"MSG"Cast"MSG"Unsqueeze"MSG"Unsqueeze"MSG"Unsqueeze"MSG"Div"MSG"Sub"MSG"Squeeze"MSG"Add"MSG"Mul"MSG"Cast"MSG"Unsqueeze"MSG"Squeeze"MSG"Cast"MSG"Unsqueeze"MSG"Unsqueeze"MSG"Sub"MSG"Add"MSG"Squeeze"MSG"Cast"MSG"Unsqueeze"MSG"Unsqueeze"MSG"Unsqueeze"MSG"Div"MSG"Sub"MSG"Squeeze"MSG"Add"MSG"Mul"MSG"Cast"MSG"Unknown aten::arange signature taking "MSG" arguments."MSG"ATen"MSG"index"MSG"Byte"MSG"Bool"MSG"Exporting masked indices are only supported after ONNX opset 9."MSG"Exporting aten::index operator with indices of type Byte. "MSG"Only 1-D indices are supported. In any other case, "MSG"this will produce an incorrect ONNX graph."MSG"Unsupported aten::index operator of advanced indexing on tensor of unknown rank, "MSG"try turning on shape and type propagate during export: "MSG"torch.onnx._export(..., propagate=True)."MSG"Exporting aten::index operator of advanced indexing in opset "MSG" is achieved by combination of multiple ONNX operators, "MSG"including Reshape, Transpose, Concat, and Gather. "MSG"If indices include negative values, the exported graph will produce incorrect results."MSG"Gather"MSG"Constant"MSG"Transpose"MSG"Flatten"MSG"Mul"MSG"Add"MSG"Mul"MSG"Constant"MSG"Concat"MSG"Reshape"MSG"Transpose"MSG"Concat"MSG"Concat"MSG"Reshape"MSG"Multinomial"MSG"generator is not supported for multinomial"MSG"Multinomial"MSG"replacement=False when num_samples > 1 is not supported for multinomial"MSG"Multinomial"MSG"Cast"MSG"Cast"MSG"Constant"MSG"Shape"MSG"Concat"MSG"Constant"MSG"Concat"MSG"Expand"MSG"prim::ListConstruct"MSG"Div"MSG"Floor"MSG"Mul"MSG"Sub"MSG"ATen"MSG"group_norm"MSG"Constant"MSG"Constant"MSG"InstanceNormalization"MSG"Shape"MSG"Constant"MSG"Constant"MSG"Unsqueeze"MSG"Unsqueeze"MSG"Div"MSG"Mul"MSG"ATen"MSG"_weight_norm"MSG"Constant"MSG"Constant"MSG"Constant"MSG"ReduceMean"MSG"ReduceSum"MSG"kl_div with reduction other than none, mean, or sum. Please open a bug to "MSG"request ONNX export support for the missing reduction type."MSG"Reshape"MSG"Constant"MSG"Gather"MSG"Constant"MSG"Constant"MSG"Constant"MSG"Reshape"MSG"Constant"MSG"Mul"MSG"Constant"MSG"Add"MSG"Add"MSG"Constant"MSG"Gather"MSG"ONNX export failed. The ONNX domain and/or version to register are None."MSG"ONNX export failed. The ONNX domain and/or version are None."MSG"ONNX export failed. The ONNX domain and/or version are None."MSG"Exporting the operator "MSG" to ONNX opset version "MSG" is not supported. "MSG"Support for this operator was added in version "MSG", try exporting with this version."MSG"Please open a bug to request ONNX export support for the missing operator."MSG"You are exporting the model to ONNX while in training mode with "MSG"'MSG' parameter not specified. The model will default to inference mode export. "MSG"If you wish to export a training amenable ONNX model, specify training=TrainingMode.TRAINING or "MSG"training=TrainingMode.PRESERVE (to preserve the original model state) in torch.onnx.export()."MSG"You are exporting the model in training mode with onnx opset version {}. "MSG"Opset versions lower than opset 12 will not be able to export nodes such as"MSG"Dropout and BatchNorm correctly."MSG"prim::Constant"MSG"prim::ListConstruct"MSG"We detected that you are modifying a dictionnary that is an input to your "MSG"model. "MSG"Note that dictionaries are allowed as inputs in ONNX but they should be "MSG"handled with care. "MSG"Usages of dictionaries is not recommended, and should not be used except "MSG"for configuration use. "MSG"Also note that the order and values of the keys must remain the same. "MSG"The model seems to have string inputs/outputs. "MSG"Note that strings will not appear as inputs/outputs of the ONNX graph. "MSG"`{}'MSG'operator_export_type'MSG'operator_export_type'MSG'ONNX'MSG'keep_initializers_as_inputs=False'MSG'keep_initializers_as_inputs=False'MSG'do_constant_folding=False'MSG'training=TrainingMode.TRAIN'MSG'training=TrainingMode.PRESERVE'MSG'{}'MSG'dim_i'MSG'dims_i'MSG't support per-parameter options "MSG"(parameter groups)"MSG"strong_wolfe"MSG"only 'MSG' is supported"MSG"The epoch parameter in `scheduler.step()` was not necessary and is being "MSG"deprecated where possible. Please use `scheduler.step()` to step the "MSG"scheduler. During the deprecation, if epoch is different from None, the "MSG"closed form is used instead of the new chainable form, where available. "MSG"Please open an issue if you are unable to replicate your use case: "MSG"https://github.com/pytorch/pytorch/issues/new/choose."MSG"Please also save or load the state of the optimizer when saving or loading the scheduler."MSG"param 'MSG' is not specified "MSG"in param_groups[{}] when resuming an optimizer"MSG"_with_counter"MSG"Seems like `optimizer.step()` has been overridden after learning rate scheduler "MSG"initialization. Please, make sure to call `optimizer.step()` before "MSG"`lr_scheduler.step()`. See more details at "MSG"https://pytorch.org/docs/stable/optim.html elif self.optimizer._step_count < 1: + warnings.warn("MSG" + "MSG" + "MSG" + "MSG" + "MSG" + "MSG"_get_closed_form_lr"MSG"Expected {} lr_lambdas, but got {}"MSG"To get the last learning rate computed by the scheduler, "MSG"please use `get_last_lr()`."MSG"Expected {} lr_lambdas, but got {}"MSG"To get the last learning rate computed by the scheduler, "MSG"please use `get_last_lr()`."MSG"To get the last learning rate computed by the scheduler, "MSG"please use `get_last_lr()`."MSG"To get the last learning rate computed by the scheduler, "MSG"please use `get_last_lr()`."MSG"To get the last learning rate computed by the scheduler, "MSG"please use `get_last_lr()`."MSG"To get the last learning rate computed by the scheduler, "MSG"please use `get_last_lr()`."MSG"expected {} min_lrs, got {}"MSG"expected {} values for {}, got {}"MSG"To get the last learning rate computed by the scheduler, "MSG"please use `get_last_lr()`."MSG"Expected positive integer T_0, but got {}"MSG"Expected integer T_mult >= 1, but got {}"MSG"To get the last learning rate computed by the scheduler, "MSG"please use `get_last_lr()`."MSG"Expected non-negative epoch, but got {}"MSG"You must define either total_steps OR (epochs AND steps_per_epoch)"MSG"Expected positive integer total_steps, but got {}"MSG"Expected positive integer epochs, but got {}"MSG"Expected positive integer steps_per_epoch, but got {}"MSG"Expected float between 0 and 1 pct_start, but got {}"MSG"anneal_strategy must by one of 'MSG' or 'MSG', instead got {}"MSG"expected {} values for {}, got {}"MSG"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."MSG"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."MSG"To get the last learning rate computed by the scheduler, "MSG"please use `get_last_lr()`."MSG"Tried to step {} times. The specified number of total steps is {}"MSG""MSG"python.optimizer"MSG"params argument given to the optimizer should be "MSG"an iterable of Tensors or dicts, but got "MSG"optimizer got an empty parameter list"MSG"loaded state dict has a different number of "MSG"parameter groups"MSG"loaded state dict contains a parameter group "MSG"that doesn'MSG's group"MSG"param group must be a dict"MSG"optimizer can only optimize Tensors, "MSG"but one of the params is "MSG"can'MSG't specify a value of required optimization parameter "MSG"optimizer contains a parameter group with duplicate parameters; "MSG"in future, this will cause an error; "MSG"see github.com/pytorch/pytorch/issues/40967 for more information"MSG"some parameters appear in more than one parameter group"MSG"Invalid learning rate: {}"MSG"Invalid epsilon value: {}"MSG"Invalid momentum value: {}"MSG"Invalid weight_decay value: {}"MSG"Invalid alpha value: {}"MSG"Invalid learning rate: {}"MSG"Invalid eta values: {}, {}"MSG"Invalid learning rate: {}"MSG"Invalid momentum value: {}"MSG"Invalid weight_decay value: {}"MSG"Nesterov momentum requires a momentum and zero dampening"MSG"Invalid learning rate: {}"MSG"Invalid epsilon value: {}"MSG"Invalid beta parameter at index 0: {}"MSG"Invalid beta parameter at index 1: {}"MSG"params"MSG"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors"MSG"anneal_strategy must by one of 'MSG' or 'MSG', "MSG"instead got {}"MSG"anneal_epochs must be a positive integer, got {}"MSG"swa_lr must have the same length as "MSG"optimizer.param_groups: swa_lr has {}, "MSG"optimizer.param_groups has {}"MSG"To get the last learning rate computed by the scheduler, "MSG"please use `get_last_lr()`."MSG"Invalid learning rate: {}"MSG"Invalid rho value: {}"MSG"Invalid epsilon value: {}"MSG"Invalid weight_decay value: {}"MSG"Invalid learning rate: {}"MSG"Invalid epsilon value: {}"MSG"Invalid beta parameter at index 0: {}"MSG"Invalid beta parameter at index 1: {}"MSG"Invalid weight_decay value: {}"MSG"Invalid learning rate: {}"MSG"Invalid epsilon value: {}"MSG"Invalid beta parameter at index 0: {}"MSG"Invalid beta parameter at index 1: {}"MSG"Invalid weight_decay value: {}"MSG"Invalid learning rate: {}"MSG"Invalid epsilon value: {}"MSG"Invalid beta parameter at index 0: {}"MSG"Invalid beta parameter at index 1: {}"MSG"Invalid weight_decay value: {}"MSG"Invalid learning rate: {}"MSG"Invalid weight_decay value: {}"MSG"Invalid learning rate: {}"MSG"Invalid epsilon value: {}"MSG"Invalid momentum value: {}"MSG"Invalid weight_decay value: {}"MSG"Invalid alpha value: {}"MSG"Invalid learning rate: {}"MSG"Invalid eta values: {}, {}"MSG"Invalid learning rate: {}"MSG"Invalid momentum value: {}"MSG"Invalid weight_decay value: {}"MSG"Nesterov momentum requires a momentum and zero dampening"MSG"Conv and BN both must be in the same mode (train or eval)."MSG"Conv and BN both must be in the same mode (train or eval)."MSG"Cannot fuse train modules: {}"MSG"Cannot fuse eval modules: {}"MSG"Cannot fuse modules: {}"MSG"ABC"MSG"Please use quant_min and quant_max to specify the range for observers. \ + reduce_range will be deprecated in a future release of PyTorch."MSG"Default Observer only works for per_tensor_affine, \ + per_tensor_symmetric, per_channel_affine, \ + per_channel_symmetric and per_channel_float_qparams quantization scheme"MSG"Default Observer only works for qint8 and quint8 data type"MSG"Used-specified quantization range must include 0."MSG"qmin must be strictly less than qmax for user-specified quantization range."MSG"quantization range should be positive and not exceed the maximum bit range (=256)."MSG"must run observer before calling calculate_qparams.\ + Returning default scale and zero point "MSG"must run observer before calling calculate_qparams.\ + Returning default scale and zero point "MSG"min {} should be less than max {}"MSG"min {} should be less than max {}"MSG"Cannot reduce range for symmetric \ + quantization for quint8"MSG"min_val={}, max_val={}"MSG"min {} should be less than max {}"MSG"Cannot reduce range for symmetric quantization for quint8"MSG"min_val={}, max_val={}"MSG"L2"MSG"Only L2 norms are currently supported"MSG"L2"MSG"bins mistmatch"MSG"inf"MSG"L2"MSG"must run observer before calling calculate_qparams.\ + Returning default scale and zero point "MSG"The number of bins in histogram should be equal to the number of bins "MSG"supplied while making this observer"MSG""MSG"calculate_qparams should not be called for PlaceholderObserver"MSG"tensor_val"MSG"calculate_qparams should not be called for RecordingObserver"MSG""MSG"calculate_qparams should not be called for NoopObserver"MSG"torch.quantization.observer.PerChannelMinMaxObserver"MSG"torch.quantization.observer.MovingAveragePerChannelMinMaxObserver"MSG"Missing keys for observer {} in state_dict"MSG"Unexpected keys for observer {} in state_dict"MSG"QConfig received observer instance, please pass observer class instead. "MSG"Use MyObserver.with_args(x=1) to override arguments to constructor if needed"MSG"QConfigDynamic received observer instance, please pass observer class instead. "MSG"Use MyObserver.with_args(x=1) to override arguments to constructor if needed"MSG"add_observer_ only works with cpu or single-device CUDA modules, "MSG"but got devices {}"MSG"None of the submodule got qconfig applied. Make sure you "MSG"passed correct configuration through `qconfig_dict` or "MSG"by assigning the `.qconfig` attribute directly on submodules"MSG"qconfig"MSG"Don'MSG't, or both are nan and "MSG"equal_nan is False"MSG"Comparing"MSG"{0} and {1} gives a "MSG"difference of {2}, but the allowed difference "MSG"with rtol={3} and atol={4} is "MSG"only {5}!"MSG" the real part "MSG" the imaginary part "MSG" "MSG"rtol and atol must both be specified or both be unspecified"MSG"eq"MSG"ge"MSG"gt"MSG"le"MSG"lt"MSG"ne"MSG"add"MSG"div"MSG"mul"MSG"__eq__"MSG"__ge__"MSG"__gt__"MSG"__le__"MSG"__lt__"MSG"__ne__"MSG"__add__"MSG"__div__"MSG"__mul__"MSG"_convolution"MSG"_convolution"MSG"_convolution_nogroup"MSG"conv1d"MSG"conv2d"MSG"conv3d"MSG"conv_tbc"MSG"conv_transpose1d"MSG"conv_transpose2d"MSG"conv_transpose3d"MSG"convolution"MSG"cudnn_convolution"MSG"cudnn_convolution_transpose"MSG"cudnn_convolution"MSG"cudnn_convolution_transpose"MSG"cudnn_convolution"MSG"cudnn_convolution_transpose"MSG"prelu"MSG"addmm"MSG"addmv"MSG"addr"MSG"matmul"MSG"mm"MSG"mv"MSG"chain_matmul"MSG"addbmm"MSG"baddbmm"MSG"bmm"MSG"lstm_cell"MSG"gru_cell"MSG"rnn_tanh_cell"MSG"rnn_relu_cell"MSG"acos"MSG"asin"MSG"cosh"MSG"erfinv"MSG"exp"MSG"expm1"MSG"log"MSG"log10"MSG"log2"MSG"log1p"MSG"reciprocal"MSG"rsqrt"MSG"sinh"MSG"tan"MSG"pow"MSG"pow"MSG"softmax"MSG"log_softmax"MSG"layer_norm"MSG"group_norm"MSG"norm"MSG"norm"MSG"dim"MSG"norm"MSG"p"MSG"norm"MSG"p"MSG"dim"MSG"cosine_similarity"MSG"poisson_nll_loss"MSG"cosine_embedding_loss"MSG"hinge_embedding_loss"MSG"kl_div"MSG"margin_ranking_loss"MSG"triplet_margin_loss"MSG"binary_cross_entropy_with_logits"MSG"cumprod"MSG"cumsum"MSG"dist"MSG"pdist"MSG"cdist"MSG"prod"MSG"prod"MSG"renorm"MSG"sum"MSG"sum"MSG"addcdiv"MSG"addcmul"MSG"atan2"MSG"bilinear"MSG"cross"MSG"cat"MSG"dot"MSG"equal"MSG"index_put"MSG"index_put"MSG"stack"MSG"tensordot"MSG"linear"MSG"softplus"MSG"gelu"MSG"nll_loss"MSG"nll_loss2d"MSG"l1_loss"MSG"smooth_l1_loss"MSG"mse_loss"MSG"multilabel_margin_loss"MSG"soft_margin_loss"MSG"multi_margin_loss"MSG"__matmul__"MSG"__pow__"MSG"binary_cross_entropy"MSG"cuda:0"MSG"cuda:{}"MSG"this decorator only support function with signature (self, device) or (self, device, dtype)"MSG"."MSG"_"MSG"_"MSG"_"MSG"_"MSG"Skipped!"MSG"Redefinition of test {0}"MSG"op_list"MSG"_base"MSG"same device cannot appear in except_for and only_for"MSG"Couldn'MSG'{0}'MSG't run on {0}"MSG"precisionOverride not given a dtype : precision dict!"MSG"precisionOverride given unknown dtype {0}"MSG"When one dtype variant is a tuple or list, "MSG"all dtype variants must be. "MSG"Received non-list non-tuple dtype {0}"MSG"Unknown dtype in {0}"MSG"Unknown dtype in {0}"MSG"dtypes redefinition for {0}"MSG"'MSG'"MSG"'MSG'"MSG"PyTorch compiled without Lapack"MSG"PyTorch is built without MKL support"MSG"no MAGMA library detected"MSG"test doesn'MSG't currently work on the CUDA stack"MSG"cuDNN not available"MSG"cuDNN version {0} is available but {1} required"MSG"backend_unavailable"MSG"Skipped because distributed backend is not available."MSG"small_worldsize"MSG"Skipped due to small world size."MSG"no_cuda"MSG"CUDA is not available."MSG"multi-gpu"MSG"Need at least 2 CUDA devices"MSG"nccl"MSG"c10d not compiled with NCCL support"MSG"skipIfRocm"MSG"Test skipped for ROCm"MSG"no_cuda"MSG"WORLD_SIZE"MSG"Need at least {} CUDA devices"MSG"WORLD_SIZE"MSG"multi-gpu"MSG"multi-gpu"MSG"BACKEND"MSG"mpi"MSG"WORLD_SIZE"MSG"small_worldsize"MSG"Need at least {} CUDA devices"MSG"multi-gpu"MSG"nccl"MSG"Need at least {} CUDA devices"MSG"multi-gpu"MSG"Need at least {} CUDA devices"MSG"multi-gpu"MSG"c10d was not compiled with the Gloo backend"MSG"c10d was not compiled with the NCCL backend"MSG"Requires NCCL version greater than or equal to: {}, found: {}, reason: {}"MSG"c10d was not compiled with the NCCL backend"MSG"c10d was not compiled with the MPI backend"MSG"Test skipped for ROCm"MSG"This unit test case is not supportted on Windows platform"MSG"test_ddp_uneven_inputs"MSG"127.0.0.1"MSG"TEMP_DIR"MSG"barrier"MSG"test_dir"MSG"init_dir"MSG"INIT_METHOD"MSG"INIT_METHOD"MSG"shared_init_file"MSG"."MSG"fork"MSG"spawn"MSG"Process {} terminated with exit code {}, terminating remaining processes."MSG"Timing out after {} seconds and killing subprocesses."MSG"Processes {} exited with error code {}"MSG" "MSG"Expect process {} exit code to match Process 0 exit code of {}, but got {}"MSG"Expected zero exit code but got {}"MSG"_"MSG"2d_1d"MSG"1d_2d"MSG"2d_2d"MSG"3d_1d"MSG"3d_2d"MSG"1d_3d"MSG"2d_3d"MSG"4d_4d"MSG"4d_1d"MSG"1d_4d"MSG"n=2"MSG"n=3"MSG"n=1"MSG"n=0"MSG"n=-1"MSG"n=-3"MSG"n=-2"MSG"single_matrix"MSG"batch_of_matrices"MSG"p=1"MSG"p=2"MSG"p=3"MSG"p=5"MSG"User provided tensor is real for a test that runs with complex dtype, "MSG"which is not supported for now"MSG"_"MSG"relu"MSG"1d"MSG"{}: Specify {} by a value, a function to generate it, or it'MSG't currently work on the ROCm stack"MSG"."MSG"PyTorch was compiled without numpy support"MSG"Cannot import `caffe2.python.core`"MSG"test require SciPy, but SciPy not found"MSG"test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test"MSG"ignore"MSG"pytorch_ci"MSG"dev"MSG"debug"MSG"pytorch_ci"MSG"PYTORCH_RUN_DISABLED_TESTS"MSG"0"MSG"1"MSG"Couldn'MSG'{}'MSG'a'MSG', r'MSG'{}'MSG' grads cancel each other. Received "MSG"gradient {grad}"MSG"The grad for any non-ddp parameter shouldn'MSG'win32'MSG'win32'MSG'win32'MSG'nccl'MSG'no_cuda'MSG'MASTER_PORT'MSG'env://'MSG't supports GPU barrier"MSG"mpi"MSG"MPI doesn'MSG't supports GPU barrier"MSG"nccl"MSG"NCCL does not support CPU barrier"MSG"nccl"MSG"NCCL does not support CPU barrier"MSG"nccl"MSG"NCCL does not support CPU barrier"MSG"mpi"MSG"MPI doesn'MSG't support broadcast multigpu"MSG"nccl"MSG"CUDA all_reduce multigpu skipped for NCCL"MSG"nccl"MSG"Only Nccl backend supports reduce multigpu"MSG"nccl"MSG"Only Nccl backend supports allgather multigpu"MSG"WORLD_SIZE"MSG"file://"MSG"nccl"MSG"nccl does not support DDP on CPU models"MSG"nccl"MSG"nccl does not support DDP on CPU models"MSG"Only Nccl & Gloo backend support DistributedDataParallel"MSG"nccl"MSG"gloo"MSG"Only NCCL and GLOO backend support DistributedDataParallel"MSG"WORLD_SIZE"MSG"WORLD_SIZE"MSG"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}"MSG"Only Nccl & Gloo backend support DistributedDataParallel"MSG"Only Nccl & Gloo backend support DistributedDataParallel"MSG"Only Nccl & Gloo backend support DistributedDataParallel"MSG"Only Nccl & Gloo backend support DistributedDataParallel"MSG"Only Nccl & Gloo backend support DistributedDataParallel"MSG"Only Nccl & Gloo backend support DistributedDataParallel"MSG"Only Nccl & Gloo backend support DistributedDataParallel"MSG"Reduction fn {reduction_fn} must specify dst!"MSG"nccl"MSG"nccl"MSG"nccl"MSG"nccl"MSG"nccl"MSG"nccl"MSG"WORLD_SIZE"MSG"nccl"MSG"nccl"MSG"nccl"MSG"gloo"MSG"Only NCCL and GLOO backend support DistributedDataParallel"MSG"WORLD_SIZE"MSG"nccl"MSG"gloo"MSG"WORLD_SIZE"MSG"BACKEND"MSG"gloo"MSG"nccl"MSG"NCCL does not support gather"MSG"Can'MSG'cpu2'MSG't call localValue\(\) on user "MSG"WorkerInfo\(id={self.rank}, name={worker_name(self.rank)}\). "MSG"Call it on owner WorkerInfo\(id={next_rank}, name={worker_name(next_rank)}\)"MSG"rpc_sync + return any([e.name.startswith(expected_name) for e in events]) dst = worker_name((self.rank + 1) % self.world_size) + rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1)) with torch.autograd.profiler.profile() as p: + t = rref._get_type() self.assertTrue(launched_rpc(p.function_events)) + self.assertEqual(t, type(torch.ones(2))) with torch.autograd.profiler.profile() as p: + for _ in range(10): + t = rref._get_type() self.assertFalse(launched_rpc(p.function_events)) + self.assertEqual(t, type(torch.ones(2))) rref = rpc.remote(dst, MyClass, args=(0,)) + self.assertEqual(rref._get_type(), MyClass) @dist_init + def test_rref_type_with_error(self): + dst = worker_name((self.rank + 1) % self.world_size) + rref = rpc.remote(dst, raise_func) with self.assertRaisesRegex(ValueError, "MSG"): + rref._get_type() @dist_init + def test_rref_type_owner(self): + rref = RRef(torch.ones(2) + 1) + self.assertEqual(rref._get_type(), type(torch.ones(2))) rref = RRef(MyClass(0)) + self.assertEqual(rref._get_type(), MyClass) @staticmethod + def _slow_add(x, y): + time.sleep(1) + return x + y @dist_init + def test_rref_type_slow_init(self): + dst = worker_name((self.rank + 1) % self.world_size) + rref = rpc.remote(dst, RpcTest._slow_add, args=(torch.ones(2), 1)) + self.assertEqual(rref._get_type(), type(torch.ones(2))) @dist_init + def test_owner_equality(self): + a = RRef(40) + b = RRef(50) other_rank = (self.rank + 1) % self.world_size + other_a = rpc.remote( + worker_name(other_rank), torch.add, args=(torch.ones(1), 1) + ) + other_b = rpc.remote( + worker_name(other_rank), torch.add, args=(torch.ones(1), 1) + ) + other_a.to_here() + other_b.to_here() self.assertNotEqual(a.owner(), 23) + self.assertEqual(other_a.owner(), other_b.owner()) + self.assertNotEqual(a.owner(), other_a.owner()) + self.assertEqual(other_a.owner(), other_a.owner()) + self.assertEqual(other_a.owner(), other_b.owner()) + self.assertEqual(a.owner(), a.owner()) + self.assertEqual(a.owner(), b.owner()) + self.assertEqual(a.owner(), rpc.get_worker_info()) + x = dict() + x[a.owner()] = a + x[other_a.owner()] = other_a + self.assertEqual(x[a.owner()], a) + self.assertEqual(x[b.owner()], a) + self.assertEqual(x[other_a.owner()], other_a) + self.assertEqual(x[other_b.owner()], other_a) + self.assertEqual(len(x), 2) @dist_init + def test_pass_local_rrefs(self): + n = self.rank + 1 + dst_rank = n % self.world_size + dst_worker = worker_name(dst_rank) rref = RRef(40) + self.assertEqual( + rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50)), 90 + ) + self.assertEqual( + rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90 + ) + self.assertEqual( + rpc.remote(dst_worker, add_rref_to_value, args=(rref, 50)).to_here(), 90 + ) @dist_init + def test_remote_same_worker(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref_a = rpc.remote( + worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2) + ) + rref_b = rpc.remote( + worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1) + ) + rref_c = rpc.remote( + worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b) + ) + self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4) @dist_init(setup_rpc=True) + def test_call_method_on_rref(self): + vals = [10, 2, 5, 7] + dst_rank = (self.rank + 1) % self.world_size + dst_worker = worker_name(dst_rank) rref = rpc.remote(dst_worker, MyClass, args=(vals[0],)) rpc.rpc_sync( + rref.owner(), + _call_method_on_rref, + args=(MyClass.increment_value, rref, vals[1]), + ) + rpc.rpc_async( + rref.owner(), + _call_method_on_rref, + args=(MyClass.increment_value, rref, vals[2]), + ).wait() + rpc.remote( + rref.owner(), + _call_method_on_rref, + args=(MyClass.increment_value, rref, vals[3]), + ).to_here() result = rpc.rpc_sync( + dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref) + ) self.assertEqual(result, sum(vals)) + + + + + @mock.patch.object(torch.distributed.rpc.api, "MSG") + def _test_rref_leak(self, _mock_delete_all_user_and_unforked_owner_rrefs, ignore_leak): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) initialize_pg(self.init_method, self.rank, self.world_size) + dist.barrier() rref = rpc.remote( + worker_name((self.rank + 1) % self.world_size), + torch.add, + args=(torch.ones(2, 2), 1), + ) import torch.distributed.rpc.api as api if ignore_leak: + api._ignore_rref_leak = True + rpc.shutdown(graceful=True) + else: + api._ignore_rref_leak = False + with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc.shutdown(graceful=True) @dist_init(setup_rpc=False) + def test_rref_leak(self): + self._test_rref_leak(ignore_leak=False) @dist_init(setup_rpc=False) + def test_ignore_rref_leak(self): + self._test_rref_leak(ignore_leak=True) @dist_init + def test_rref_str(self): + rref1 = RRef(self.rank) + id_class = "MSG" + self.assertEqual( + "MSG".format(id_class, self.rank), rref1.__str__() + ) dst_rank = (self.rank + 1) % self.world_size + rref2 = rpc.remote( + worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1) + ) + self.assertEqual( + rref2.__str__(), + "MSG".format( + id_class, self.rank + ), + ) @dist_init + def test_rref_get_future(self): + if self.rank == 0: + rref = rpc.remote(worker_name(1), torch.add, args=(1, 1)) + rref.to_here() + fut = rref._get_future() + self.assertIsInstance(fut, torch._C.Future) rref = rpc.remote(worker_name(1), foo_add, args=()) + rref.to_here() + fut = rref._get_future() + self.assertIsInstance(fut, torch._C.Future) rref = rpc.remote(worker_name(1), my_script_func, args=(torch.tensor(1), )) + rref.to_here() + fut = rref._get_future() + self.assertIsInstance(fut, torch._C.Future) + @dist_init + def test_rref_context_debug_info(self): + initialize_pg(self.init_method, self.rank, self.world_size) rref1 = RRef(self.rank) info = _rref_context_get_debug_info() + self.assertIn("MSG", info) + self.assertIn("MSG", info) + self.assertEqual(0, int(info["MSG"])) + self.assertEqual(0, int(info["MSG"])) + dist.barrier() dst_rank = (self.rank + 1) % self.world_size + rpc.rpc_sync(worker_name(dst_rank), set_global_rref, args=(rref1,)) wait_until_pending_futures_and_users_flushed() + dist.barrier() info = _rref_context_get_debug_info() + self.assertIn("MSG", info) + self.assertEqual(1, int(info["MSG"])) + self.assertEqual(0, int(info["MSG"])) + dist.barrier() rpc.rpc_sync(worker_name(dst_rank), clear_global_rref) rref2 = rpc.remote( + worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1) + ) + rref3 = rpc.remote( + worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1) + ) + rref2.to_here() + rref3.to_here() wait_until_pending_futures_and_users_flushed() + dist.barrier() info = _rref_context_get_debug_info() + self.assertIn("MSG", info) + self.assertEqual(2, int(info["MSG"])) + self.assertEqual(0, int(info["MSG"])) dist.barrier() @dist_init + def test_disable_gil_profiling(self): + dst_rank = (self.rank + 1) % self.world_size + rpc.rpc_sync( + worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1)) + ) + info = rpc.api._get_current_rpc_agent().get_debug_info() + self.assertRaises(KeyError, lambda: info["MSG"]) + rpc.enable_gil_profiling(True) + rpc.rpc_sync( + worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1)) + ) + info = rpc.api._get_current_rpc_agent().get_debug_info() + self.assertIn("MSG", info) @dist_init(setup_rpc=False) + def test_local_shutdown(self): + rpc.init_rpc( + name="MSG" % self.rank, + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + rpc.shutdown(graceful=False) @dist_init + def test_debug_info(self): + import torch.distributed.autograd as dist_autograd info = _get_debug_info() + rref_info = _rref_context_get_debug_info() + agent_info = rpc.api._get_current_rpc_agent().get_debug_info() + autograd_info = dist_autograd._get_debug_info() + common_keys = rref_info.keys() & agent_info.keys() & autograd_info.keys() + self.assertEqual(0, len(common_keys)) + expected = {} + expected.update(rref_info) + expected.update(agent_info) + expected.update(autograd_info) + for key in expected.keys(): + self.assertIn(key, info.keys()) for key in info.keys(): + self.assertIn(key, expected.keys()) @dist_init(setup_rpc=False) + @unittest.skipIf( + IS_MACOS, + "MSG", + ) + def test_handle_send_exceptions(self): + rpc.init_rpc( + name="MSG" % self.rank, + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + rpc._set_rpc_timeout(10) + initialize_pg(self.init_method, self.rank, self.world_size) + dist.barrier() + if self.rank == 1: + dst_rank = (self.rank + 1) % self.world_size + dst_worker = worker_name(dst_rank) + error_str = self.get_shutdown_error_regex() + wait_until_node_failure(dst_rank, error_str) + fut = rpc.rpc_async(dst_worker, torch.add, args=(torch.ones(1), 3)) + with self.assertRaisesRegex(RuntimeError, error_str): + fut.wait() + rpc.shutdown(graceful=False) @dist_init + def test_deadlock(self): + if self.rank == 1: + dst1 = worker_name((self.rank + 1) % self.world_size) + x = torch.ones(2) + y = torch.ones(2) + rpc.rpc_async(dst1, RpcTest._slow_add, args=(x, y), timeout=15).wait() dist_initialized = dist.is_initialized() + if not dist_initialized: + dist.init_process_group( + backend="MSG", + init_method=self.init_method, + rank=self.rank, + world_size=self.world_size, + ) @dist_init(setup_rpc=False) + def test_local_shutdown_with_rpc(self): + rpc.init_rpc( + name="MSG" % self.rank, + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + n = self.rank + 1 + dst_rank = n % self.world_size + rpc.rpc_sync( + worker_name(dst_rank), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + initialize_pg(self.init_method, self.rank, self.world_size) + dist.barrier() + rpc.shutdown(graceful=False) @dist_init(setup_rpc=False) + def test_set_and_get_default_rpc_timeout(self): + timeout = 0.5 rpc_backend_options = self.rpc_backend_options + rpc_backend_options.rpc_timeout = timeout rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) + set_timeout = rpc.get_rpc_timeout() + self.assertEqual(timeout, set_timeout) + rpc.shutdown() @dist_init + def test_default_timeout_used(self): + dst_rank = (self.rank + 1) % self.world_size + rpc._set_rpc_timeout(0.001) + futs = [ + rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()) + for _ in range(10) + ] + expected_error = self.get_timeout_error_regex() + for fut in futs: + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() rpc._set_rpc_timeout(200) + fut1 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,)) + rpc._set_rpc_timeout(0.001) + fut2 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,)) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut2.wait() + fut1.wait() rpc._set_rpc_timeout(0) + rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()).wait() rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) @dist_init + def test_rpc_timeouts(self): + dst_rank = (self.rank + 1) % self.world_size + dst_worker = worker_name(dst_rank) + timeout = 0.1 + expected_error = self.get_timeout_error_regex() + fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=timeout) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)).wait() with self.assertRaisesRegex(RuntimeError, expected_error): + rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=timeout) rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,)) rpc._set_rpc_timeout(0.001) + fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() + with self.assertRaisesRegex(RuntimeError, expected_error): + rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,)) rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=5).wait() + rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=5) + rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=0).wait() + rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=0) + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) def test_dist_init_decorator(self): + @dist_init(setup_rpc=False) + def test_func(self): + return "MSG" self.assertEqual(test_func(self), "MSG") @dist_init + def test_func(self): + return "MSG" self.assertEqual(test_func(self), "MSG") def test_use_rpc_pickler(self): + class TestPickler: + pass test_pickler = TestPickler() + with _use_rpc_pickler(test_pickler): + self.assertTrue(torch.distributed.rpc.api._default_pickler is test_pickler) + self.assertTrue( + torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler + ) @dist_init + def test_function_not_on_callee(self): + this_module = sys.modules[__name__] + caller_worker = "MSG" + callee_worker = "MSG" if self.rank == 1: + delattr(this_module, "MSG") + rpc.rpc_sync(caller_worker, set_value, args=(self.rank,)) if self.rank == 0: + wait_for_value_future() + self.assertTrue(hasattr(this_module, "MSG")) + with self.assertRaisesRegex( + AttributeError, "MSG" + ): + rpc.rpc_sync(callee_worker, foo_add, args=()) @dist_init + def test_non_garbage_collected_user_rref_due_to_local_circular_dependency(self): + dst_worker_name = worker_name((self.rank + 1) % self.world_size) a = MyClass(1) + b = MyClass(2) a.other = b + b.other = a n = self.rank + a.rref = rpc.remote( + dst_worker_name, + torch.add, + args=(torch.ones(n, n), 2) + ) @dist_init(setup_rpc=False) + def test_use_rref_after_shutdown(self): + rpc.init_rpc( + name="MSG" % self.rank, + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + n = self.rank + 1 + dst_rank = n % self.world_size + rref = rpc.remote( + worker_name(dst_rank), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + rpc.shutdown(graceful=True) with self.assertRaisesRegex( + RuntimeError, "MSG" + ): + rref.to_here() with self.assertRaisesRegex( + RuntimeError, "MSG" + ): + import torch.distributed.rpc.internal as internal + internal.serialize(rref) @staticmethod + def _return_gpu_tensor(): + return torch.rand(3, 3).cuda(0) @staticmethod + def _return_gpu_tensor_list(): + return [torch.rand(3, 3).cuda(0), torch.rand(3, 3).cuda(1)] @staticmethod + def _gpu_tensor_list_arg(tensor_list): + return torch.rand(3, 3) @skip_if_lt_x_gpu(2) + @dist_init + def test_cuda(self): + dst = worker_name((self.rank + 1) % self.world_size) + t1 = torch.rand(3, 3).cuda(0) + t2 = torch.rand(3, 3).cuda(1) + t3 = torch.rand(3, 3) with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc.rpc_sync(dst, torch.add, args=(t1, t2)) with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc.rpc_sync(dst, torch.add, args=(t1, t3)) with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc.rpc_sync(dst, RpcTest._gpu_tensor_list_arg, args=([t1, t2])) with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc.rpc_sync(dst, RpcTest._return_gpu_tensor, args=()) with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc.rpc_sync(dst, RpcTest._return_gpu_tensor_list, args=()) with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc.rpc_sync(worker_name(self.rank), torch.add, args=(t1, t2)) def _create_rref(self): + owner_rank = (self.rank + 2) % self.world_size + return rpc.remote( + worker_name(owner_rank), + torch.add, + args=(torch.zeros(2, 2), 1) + ) @dist_init + def test_user_rrefs_confirmed(self): + dst_rank = (self.rank + 1) % self.world_size + rref = self._create_rref() + ret = rpc.rpc_sync( + worker_name(dst_rank), + check_rref_confirmed, + args=(rref,) + ) + self.assertEqual(ret, True) @dist_init + def test_user_rrefs_confirmed_remote(self): + dst_rank = (self.rank + 1) % self.world_size + rref = self._create_rref() + ret_rref = rpc.remote( + worker_name(dst_rank), + check_rref_confirmed, + args=(rref,) + ) + self.assertEqual(ret_rref.to_here(), True) @dist_init + def test_rref_py_pickle_not_supported(self): + local_rref = RRef(35) + with TemporaryFileName() as fname: + with self.assertRaisesRegex(RuntimeError, "MSG"): + torch.save(local_rref, fname) @dist_init + def test_remote_throw(self): + rref = rpc.remote(worker_name((self.rank + 1) % self.world_size), + raise_or_inc, + args=(torch.ones(2),)) + with self.assertRaisesRegex(Exception, "MSG"): + rref.to_here() @dist_init + def test_non_cont_tensors(self): + if self.rank == 0: + t = torch.rand(5, 5) + t_view = t.narrow(1, 2, 2) + self.assertFalse(t_view.is_contiguous()) + t_cont = t_view.contiguous() + self.assertTrue(t_cont.is_contiguous()) + self.assertEqual(t_view, t_cont) next_rank = (self.rank + 1) % self.world_size + t_ret = rpc.rpc_sync(worker_name(next_rank), non_cont_test, args=(t_view, t_cont)) self.assertEqual(t_view, t_ret) + self.assertFalse(t_ret.is_contiguous()) @dist_init + def test_callback_simple(self): + set_by_cb = concurrent.futures.Future() + n = self.rank + 1 def callback(fut): + ret = fut.wait() + self.assertEqual(ret, torch.ones(n, n) * 2) + set_by_cb.set_result(ret.clone() + 1) fut = rpc.rpc_async( + worker_name(n % self.world_size), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)) + ) fut.then(callback) self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1) + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) @dist_init + def test_callback_wrong_arg_num(self): + set_by_cb = concurrent.futures.Future() + n = self.rank + 1 fut = rpc.rpc_async( + worker_name(n % self.world_size), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)) + ) cb_fut = fut.then(my_function) self.assertEqual(fut.wait(), torch.ones(n, n) * 2) with self.assertRaisesRegex( + RuntimeError, + "MSG" + ): + cb_fut.wait() @dist_init + def test_callback_wrong_arg_type(self): + dst = worker_name((self.rank + 1) % self.world_size) fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1)) + fut1 = fut0.then(lambda x: x + 1) with self.assertRaisesRegex( + RuntimeError, + "MSG" + ): + fut1.wait() @dist_init + def test_callback_multi(self): + num_cbs = 10 + n = self.rank + 1 def callback(idx, fut): + ret = fut.wait() + self.assertEqual(ret, torch.ones(n, n) * 2) + return ret + idx fut = rpc.rpc_async( + worker_name(n % self.world_size), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)) + ) cb_futs = [] + for idx in range(num_cbs): + cb_futs.append(fut.then(partial(callback, idx))) self.assertEqual(fut.wait(), torch.ones(n, n) * 2) for idx in range(num_cbs): + self.assertEqual( + cb_futs[idx].wait(), + torch.ones(n, n) * 2 + idx + ) self.assertEqual(fut.wait(), torch.ones(n, n) * 2) @dist_init + def test_callback_chain(self): + n = self.rank + 1 + dst = worker_name(n % self.world_size) def callback(fut): + return fut.wait() + 1 fut = rpc.rpc_async( + worker_name(n % self.world_size), + torch.add, + args=(torch.ones(n, n), 1) + ) num_cbs = 20 + for _ in range(num_cbs): + fut = fut.then(callback) self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs) @dist_init + def test_callback_in_rpc(self): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) ret = rpc.rpc_sync( + dst1, + add_use_future_cb, + args=(dst2, torch.ones(2, 2), 1, 2) + ) + self.assertEqual(ret, torch.ones(2, 2) + 1 + 2) @dist_init + def test_callback_with_ret(self): + dst = worker_name((self.rank + 1) % self.world_size) def callback(fut0): + fut2 = rpc.rpc_async( + dst, + torch.add, + args=(fut0.wait(), 1) + ).then(lambda fut1: fut1.wait() + 1) return fut2.wait() fut3 = rpc.rpc_async( + dst, + torch.add, + args=(torch.ones(2, 2), 1) + ).then(callback) self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3) @dist_init + def test_callback_with_error(self): + dst = worker_name((self.rank + 1) % self.world_size) def callback(fut0): + with self.assertRaisesRegex(ValueError, "MSG"): + fut0.wait() + raise RuntimeError("MSG") fut1 = rpc.rpc_async(dst, raise_func).then(callback) + with self.assertRaisesRegex(RuntimeError, "MSG"): + fut1.wait() @dist_init + def test_callback_none(self): + dst = worker_name((self.rank + 1) % self.world_size) + with self.assertRaisesRegex( + TypeError, + "MSG" + ): + rpc.rpc_async(dst, raise_func).then(None) @dist_init + def test_mark_future_twice(self): + fut = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + torch.add, + args=(torch.zeros(2, 2), 1) + ) + self.assertEqual(fut.wait(), torch.zeros(2, 2) + 1) + with self.assertRaisesRegex( + RuntimeError, + "MSG" + ): + fut.set_result(1) @dist_init + def test_pickle_future(self): + fut = torch.futures.Future() + errMsg = "MSG" dst = worker_name((self.rank + 1) % self.world_size) + with TemporaryFileName() as fname: + with self.assertRaisesRegex(RuntimeError, errMsg): + rpc.rpc_sync(dst, fail_on_fut, args=(fut,)) with TemporaryFileName() as fname: + with self.assertRaisesRegex(RuntimeError, errMsg): + rpc.rpc_async(dst, fail_on_fut, args=(fut,)) with TemporaryFileName() as fname: + with self.assertRaisesRegex(RuntimeError, errMsg): + rpc.remote(dst, fail_on_fut, args=(fut,)) @dist_init + def test_future_done(self): + dst = worker_name((self.rank + 1) % self.world_size) + fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1)) + fut.wait() + self.assertTrue(fut.done()) @dist_init + def test_future_done_exception(self): + dst = worker_name((self.rank + 1) % self.world_size) + fut = rpc.rpc_async(dst, raise_func) + with self.assertRaisesRegex(ValueError, "MSG"): + fut.wait() + self.assertTrue(fut.done()) def _test_future_cb(self, func): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) ret = rpc.rpc_sync( + dst1, + func, + args=(dst2, torch.ones(2, 2), 1, 2) + ) + self.assertEqual(ret, torch.ones(2, 2) + 1 + 2) @dist_init + def test_future_in_rpc(self): + self._test_future_cb(add_use_future_set_result) @dist_init + def test_future_nested_callback(self): + self._test_future_cb(add_use_future_nested_cb) def _run_func_in_mode(self, to, fn, mode, args=None, kwargs=None): + if mode == RPCExecMode.SYNC: + return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs) + elif mode == RPCExecMode.ASYNC: + return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait() + elif mode == RPCExecMode.REMOTE: + return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here() def _test_async_function_raise(self, mode): + with self.assertRaisesRegex(RuntimeError, "MSG"): + self._run_func_in_mode( + worker_name((self.rank + 1) % self.world_size), + async_raise_func, + mode + ) @dist_init + def test_async_function_raise(self): + self._test_async_function_raise(RPCExecMode.SYNC) @dist_init + def test_async_function_raise_async(self): + self._test_async_function_raise(RPCExecMode.ASYNC) @dist_init + def test_async_function_raise_remote(self): + self._test_async_function_raise(RPCExecMode.REMOTE) def _test_async_function_wrong_return_type(self, mode): + errMsg = ( + "MSG" + "MSG" + ) + with self.assertRaisesRegex(RuntimeError, errMsg): + self._run_func_in_mode( + worker_name((self.rank + 1) % self.world_size), + async_wrong_type, + mode + ) @dist_init + def test_async_function_wrong_return_type(self): + self._test_async_function_wrong_return_type(RPCExecMode.SYNC) @dist_init + def test_async_function_wrong_return_type_async(self): + self._test_async_function_wrong_return_type(RPCExecMode.ASYNC) @dist_init + def test_async_function_wrong_return_type_remote(self): + self._test_async_function_wrong_return_type(RPCExecMode.REMOTE) @dist_init + def test_async_function_simple(self): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) ret = rpc.rpc_sync(dst1, async_add, args=(dst2, torch.ones(2, 2), 1)) + self.assertEqual(ret, torch.ones(2, 2) + 1) def _test_async_function(self, fn, mode=RPCExecMode.SYNC): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) args = (dst2, torch.ones(2, 2), 1, 2) + ret = self._run_func_in_mode(dst1, fn, mode, args=args) + self.assertEqual(ret, torch.ones(2, 2) + 3) @dist_init + def test_async_function_with_future_ctor(self): + self._test_async_function(async_add_with_future_ctor) @dist_init + def test_async_function_with_future_ctor_remote(self): + self._test_async_function( + async_add_with_future_ctor, + RPCExecMode.REMOTE + ) @dist_init + def test_async_function_chained(self): + self._test_async_function(async_add_chained) @dist_init + def test_async_function_chained_remote(self): + self._test_async_function(async_add_chained, RPCExecMode.REMOTE) @dist_init + def test_async_function_nested(self): + self._test_async_function(async_add_nested) @dist_init + def test_async_function_nested_remote(self): + self._test_async_function(async_add_nested, RPCExecMode.REMOTE) @dist_init + def test_async_static_method(self): + self._test_async_function(AsyncExecutionClass.static_async_add) @dist_init + def test_async_static_method_remote(self): + self._test_async_function( + AsyncExecutionClass.static_async_add, + RPCExecMode.REMOTE + ) @dist_init + def test_async_class_method(self): + self._test_async_function(AsyncExecutionClass.class_async_add) @dist_init + def test_async_class_method_remote(self): + self._test_async_function( + AsyncExecutionClass.class_async_add, + RPCExecMode.REMOTE + ) def _test_test_async_class_rref_proxy(self, mode=RPCExecMode.SYNC): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + rref = rpc.remote(dst1, AsyncExecutionClass) x = torch.ones(2, 2) + y = torch.ones(2, 2) + 1 + if mode == RPCExecMode.SYNC: + ret = rref.rpc_sync().static_async_add(dst2, x, x, y) + ret += rref.rpc_sync().class_async_add(dst2, x, x, y) + ret += rref.rpc_sync().bound_async_add(dst2, x, x, y) + elif mode == RPCExecMode.ASYNC: + ret = rref.rpc_async().static_async_add(dst2, x, x, y).wait() + ret += rref.rpc_async().class_async_add(dst2, x, x, y).wait() + ret += rref.rpc_async().bound_async_add(dst2, x, x, y).wait() + elif mode == RPCExecMode.REMOTE: + ret = rref.remote().static_async_add(dst2, x, x, y).to_here() + ret += rref.remote().class_async_add(dst2, x, x, y).to_here() + ret += rref.remote().bound_async_add(dst2, x, x, y).to_here() self.assertEqual(ret, 3 * 4 * x) @dist_init + def test_async_class_rref_proxy(self): + self._test_test_async_class_rref_proxy() @dist_init + def test_async_class_rref_proxy_async(self): + self._test_test_async_class_rref_proxy(mode=RPCExecMode.ASYNC) @dist_init + def test_async_class_rref_proxy_remote(self): + self._test_test_async_class_rref_proxy(mode=RPCExecMode.REMOTE) def _test_async_function_multi(self, fn, mode=RPCExecMode.SYNC): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) num = 20 + step = 3 + args = (dst2, torch.ones(2, 2), num, step) + ret = self._run_func_in_mode(dst1, fn, mode, args=args) + self.assertEqual(ret, torch.ones(2, 2) + num * step) @dist_init + def test_async_function_multi_chained(self): + self._test_async_function_multi(async_add_chained_multi) @dist_init + def test_async_function_multi_chained_async(self): + self._test_async_function_multi( + async_add_chained_multi, + RPCExecMode.ASYNC + ) @dist_init + def test_async_function_multi_chained_remote(self): + self._test_async_function_multi( + async_add_chained_multi, + RPCExecMode.REMOTE + ) @dist_init + def test_async_function_multi_fanout(self): + self._test_async_function_multi(async_add_multi_fanout) @dist_init + def test_async_function_multi_fanout_async(self): + self._test_async_function_multi( + async_add_multi_fanout, + RPCExecMode.ASYNC + ) @dist_init + def test_async_function_multi_fanout_remote(self): + self._test_async_function_multi( + async_add_multi_fanout, + RPCExecMode.REMOTE + ) def _test_return_future(self, mode): + with self.assertRaisesRegex( + RuntimeError, + "MSG" + ): + self._run_func_in_mode( + worker_name((self.rank + 1) % self.world_size), + return_future, + mode + ) @dist_init + def test_return_future(self): + self._test_return_future(RPCExecMode.SYNC) @dist_init + def test_return_future_async(self): + self._test_return_future(RPCExecMode.ASYNC) @dist_init + def test_return_future_remote(self): + self._test_return_future(RPCExecMode.REMOTE) @dist_init + def test_rref_timeout(self): + if self.rank != 0: + return dst_rank = (self.rank + 1) % self.world_size + dst_worker = "MSG".format(dst_rank) + rref = rpc.remote(dst_worker, my_sleep_func, args=(2, ), timeout=0.01) + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rref._get_future().wait() + wait_until_pending_futures_and_users_flushed() + with self.assertRaisesRegex(RuntimeError, "MSG"): + rref.to_here() wait_until_owners_and_forks_on_rank(1, 1, rank=1) @dist_init(setup_rpc=False) + def test_init_pg_then_rpc(self): + dist.init_process_group( + backend="MSG", + init_method=self.init_method, + rank=self.rank, + world_size=self.world_size, + ) rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) next_rank = (self.rank + 1) % self.world_size + ret = rpc.rpc_sync(worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1)) + self.assertEqual(ret, torch.ones(2, 2) + 1) dist.barrier() rpc.shutdown() @dist_init(setup_rpc=False) + def test_init_rpc_then_pg(self): + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) dist.init_process_group( + backend="MSG", + init_method=self.init_method, + rank=self.rank, + world_size=self.world_size, + ) next_rank = (self.rank + 1) % self.world_size + ret = rpc.rpc_sync(worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1)) + self.assertEqual(ret, torch.ones(2, 2) + 1) dist.barrier() rpc.shutdown() @dist_init + def test_wait_all_with_exception(self): + futs = [] + dst = worker_name((self.rank + 1) % self.world_size) + for _ in range(10): + futs.append(rpc.rpc_async(dst, raise_func)) with self.assertRaisesRegex(ValueError, "MSG"): + ret = torch.futures.wait_all(futs) @dist_init + def test_wait_all_with_partial_exception(self): + futs = [] + dst = worker_name((self.rank + 1) % self.world_size) + for _ in range(10): + futs.append(rpc.rpc_async(dst, torch.add, args=(torch.ones(2), 1))) futs.append(rpc.rpc_async(dst, raise_func)) with self.assertRaisesRegex(ValueError, "MSG"): + ret = torch.futures.wait_all(futs) @dist_init(setup_rpc=False) + def test_init_rpc_twice(self): + initialize_pg(self.init_method, self.rank, self.world_size) rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + rpc.shutdown() dist.barrier() rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) dst = worker_name((self.rank + 1) % self.world_size) + rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1)) + rpc.rpc_sync(dst, foo_add, args=()) rpc.shutdown() def test_wrong_types(self): + with self.assertRaisesRegex( + TypeError, + "MSG", + ): + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + backend="MSG", + ) with self.assertRaisesRegex( + TypeError, + "MSG", + ): + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + backend=self.rpc_backend, + rpc_backend_options={"MSG": self.init_method} + ) def test_cannot_infer_backend_from_options(self): + rpc_backend_options = FooBackendOptions(self.init_method) with self.assertRaisesRegex(TypeError, "MSG"): + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) +class ProcessGroupAgentRpcTest(RpcAgentTestFixture): def test_mismatched_type_for_options(self): + rpc_backend_options = FooBackendOptions(self.init_method) with self.assertRaisesRegex( + TypeError, "MSG" + ): + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + backend=rpc.BackendType.PROCESS_GROUP, + rpc_backend_options=rpc_backend_options, + ) def test_infer_backend_from_options(self): + rpc_backend_options = rpc.ProcessGroupRpcBackendOptions( + init_method=self.init_method + ) with self.assertLogs("MSG", logging.WARNING) as cm: + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) + self.assertIn( + "MSG", + "MSG".join(cm.output), + ) self.assertIsInstance(rpc.api._get_current_rpc_agent(), rpc.ProcessGroupAgent) def test_logs_deprecation_warning(self): + with self.assertLogs("MSG", logging.WARNING) as cm: + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + backend=rpc.BackendType.PROCESS_GROUP, + rpc_backend_options=self.rpc_backend_options, + ) + self.assertIn( + "MSG", + "MSG".join(cm.output), + ) def test_single_threaded_rref_owner(self): + dist.init_process_group( + backend="MSG", + init_method=self.init_method, + rank=self.rank, + world_size=self.world_size, + ) caller_rank = 0 + callee_rank = 1 + rpc_backend_options = rpc.ProcessGroupRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_send_recv_threads=1 + ) if self.rank == callee_rank else self.rpc_backend_options rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) if self.rank == caller_rank: + dst = worker_name(callee_rank) + rrefs = [] info = rpc.rpc_sync(dst, get_rref_debug_info) + self.assertEqual(0, int(info["MSG"])) for i in range(20): + rrefs.append( + rpc.remote(dst, delayed_add, args=(torch.zeros(2, 2), i)) + ) futs = [] + for i in range(len(rrefs)): + futs.append( + rpc.rpc_async(dst, my_rref_function, args=(rrefs[i], rrefs[i])) + ) for i in range(len(futs)): + self.assertEqual(2 * (torch.zeros(2, 2) + i), futs[i].wait()) info = rpc.rpc_sync(dst, get_rref_debug_info) + num_owner_rrefs = int(info["MSG"]) + self.assertEqual(len(futs), num_owner_rrefs) del futs + del rrefs while num_owner_rrefs > 0: + info = rpc.rpc_sync(dst, get_rref_debug_info) + num_owner_rrefs = int(info["MSG"]) + time.sleep(0.01) dist.barrier() + rpc.shutdown() def test_single_threaded_rref_to_here(self): + dist.init_process_group( + backend="MSG", + init_method=self.init_method, + rank=self.rank, + world_size=self.world_size, + ) caller_rank = 0 + callee_rank = 1 + rpc_backend_options = rpc.ProcessGroupRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_send_recv_threads=1 + ) if self.rank == callee_rank else self.rpc_backend_options rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) if self.rank == caller_rank: + dst = worker_name(callee_rank) + rrefs = [] info = rpc.rpc_sync(dst, get_rref_debug_info) + self.assertEqual(0, int(info["MSG"])) for i in range(20): + rrefs.append( + rpc.remote(dst, delayed_add, args=(torch.zeros(2, 2), i)) + ) for i in range(len(rrefs)): + self.assertEqual(torch.zeros(2, 2) + i, rrefs[i].to_here()) info = rpc.rpc_sync(dst, get_rref_debug_info) + num_owner_rrefs = int(info["MSG"]) + self.assertEqual(len(rrefs), num_owner_rrefs) del rrefs while num_owner_rrefs > 0: + info = rpc.rpc_sync(dst, get_rref_debug_info) + num_owner_rrefs = int(info["MSG"]) + time.sleep(0.01) dist.barrier() + rpc.shutdown() @dist_init + def test_process_group_debug_info(self): + rpc.enable_gil_profiling(True) + initialize_pg(self.init_method, self.rank, self.world_size) + NUM_THREAD = self.rpc_backend_options.num_send_recv_threads info = rpc.api._get_current_rpc_agent().get_debug_info() + self.assertIn("MSG", info) + self.assertIn("MSG", info) + self.assertIn("MSG", info) + self.assertIn("MSG", info) + self.assertEqual(int(info["MSG"]), 0) + self.assertEqual(int(info["MSG"]), NUM_THREAD) + self.assertEqual(int(info["MSG"]), NUM_THREAD) + dist.barrier() + dst_rank = (self.rank + 1) % self.world_size + fut = rpc.rpc_async( + worker_name(dst_rank), set_and_check_done, args=(dst_rank,) + ) + self.assertEqual(self.rank, VALUE_FUTURE.result()) info = rpc.api._get_current_rpc_agent().get_debug_info() + self.assertIn("MSG", info) + self.assertIn("MSG", info) + self.assertIn("MSG", info) + self.assertIn("MSG", info) + self.assertGreaterEqual(float(info["MSG"]), 0) + self.assertEqual(int(info["MSG"]), 1) + self.assertEqual(int(info["MSG"]), NUM_THREAD) + num_idle_threads = int(info["MSG"]) + self.assertTrue(num_idle_threads in [NUM_THREAD - 1, NUM_THREAD - 2]) dist.barrier() DONE_FUTURE.set_result(self.rank) + self.assertEqual(dst_rank, fut.wait()) dist.barrier() info = rpc.api._get_current_rpc_agent().get_debug_info() + self.assertIn("MSG", info) + self.assertIn("MSG", info) + self.assertIn("MSG", info) + self.assertEqual(int(info["MSG"]), 0) + self.assertEqual(int(info["MSG"]), NUM_THREAD) for retry in range(3): + info = rpc.api._get_current_rpc_agent().get_debug_info() + if int(info["MSG"]) == NUM_THREAD: + break + time.sleep(0.1) + self.assertEqual(int(info["MSG"]), NUM_THREAD) dist.barrier() @dist_init(setup_rpc=False) + def test_set_and_get_num_send_recv_threads(self): + NUM_THREADS = 27 + rpc_backend_options = rpc.ProcessGroupRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_send_recv_threads=NUM_THREADS + ) + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) info = rpc.api._get_current_rpc_agent().get_debug_info() + self.assertEqual(int(info["MSG"]), NUM_THREADS) + rpc.shutdown() @dist_init(setup_rpc=False) + def test_process_group_set_default_timeout(self): + timeout = 0.5 + rpc_backend_options = rpc.ProcessGroupRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_send_recv_threads=self.rpc_backend_options.num_send_recv_threads, + rpc_timeout=timeout + ) + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) default_timeout = rpc.get_rpc_timeout() + self.assertEqual(default_timeout, timeout) + rpc.shutdown() @dist_init(setup_rpc=False) + def test_process_group_options_throw_on_timedelta_timeout(self): + from datetime import timedelta timeout = timedelta() + with self.assertRaisesRegex(TypeError, "MSG"): + rpc_backend_options = rpc.ProcessGroupRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_send_recv_threads=self.rpc_backend_options.num_send_recv_threads, + rpc_timeout=timeout, + ) +class FaultyAgentRpcTest(RpcAgentTestFixture): + + @dist_init(messages_to_delay={}) + def test_check_failed_messages(self): + if self.rank == 0: + dst_worker_b = worker_name((self.rank + 1) % self.world_size) + dst_worker_c = worker_name((self.rank + 2) % self.world_size) rref = rpc.remote(dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2))) + rpc.remote(dst_worker_c, add_rref_to_value, args=(rref, torch.ones(2, 2))) + self.assertEqual(rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2))) + _delete_all_user_and_unforked_owner_rrefs() @dist_init + def test_verify_backend_options(self): + self.assertEqual(self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_PROCESS_GROUP) + self.assertEqual(self.rpc_backend_options.num_send_recv_threads, 8) + self.assertEqual(self.rpc_backend_options.num_fail_sends, 3) + self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4) + self.assertEqual(len(self.rpc_backend_options.messages_to_delay), 2) + self.assertEqual(self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) @dist_init(faulty_messages=["MSG", "MSG"]) + def test_custom_faulty_messages(self): + self.assertEqual( + set(["MSG", "MSG"]), + set(self.rpc_backend_options.messages_to_fail), + ) @dist_init(faulty_messages=[]) + def test_no_faulty_messages(self): + self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 0) @dist_init(messages_to_delay={"MSG": 1.5}) + def test_custom_messages_to_delay(self): + self.assertEqual(self.rpc_backend_options.messages_to_delay, {"MSG": 1.5}) def _test_remote_message_dropped_pickle(self, dst=None): + if self.rank != 0: + return + dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size + dst_worker = "MSG".format(dst_rank) + rref = rpc.remote(dst_worker, my_sleep_func, args=(1,)) + wait_until_pending_futures_and_users_flushed() + with self.assertRaisesRegex(RuntimeError, "MSG"): + rref._serialize() + with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 1)) @dist_init(faulty_messages=["MSG"]) + def test_remote_message_dropped_pickle(self): + self._test_remote_message_dropped_pickle() @dist_init(faulty_messages=["MSG"]) + def test_remote_message_dropped_pickle_to_self(self): + self._test_remote_message_dropped_pickle(self.rank) + def _test_remote_message_dropped_timeout(self, func, args, dst=None): + if self.rank != 0: + return dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size + dst_worker = "MSG".format(dst_rank) + rref = rpc.remote(dst_worker, func, args=args) + wait_until_pending_futures_and_users_flushed() + with self.assertRaisesRegex(RuntimeError, "MSG"): + rref.to_here() + @dist_init(faulty_messages=["MSG"]) + def test_builtin_remote_message_dropped_timeout(self): + func = torch.add + args = (torch.tensor(1), torch.tensor(1)) + self._test_remote_message_dropped_timeout(func, args) @dist_init(faulty_messages=["MSG"]) + def test_builtin_remote_message_dropped_timeout_to_self(self): + func = torch.add + args = (torch.tensor(1), torch.tensor(1)) + self._test_remote_message_dropped_timeout(func, args, dst=0) @dist_init(faulty_messages=["MSG"]) + def test_udf_remote_message_dropped_timeout(self): + func = my_sleep_func + args = (2,) + self._test_remote_message_dropped_timeout(func, args) @dist_init(faulty_messages=["MSG"]) + def test_udf_remote_message_dropped_timeout_to_self(self): + func = my_sleep_func + args = (2,) + self._test_remote_message_dropped_timeout(func, args, dst=0) def _test_remote_message_delay_timeout(self, func, args, dst=None): + if self.rank != 0: + return + dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size + dst_worker = "MSG".format(dst_rank) + rref = rpc.remote(dst_worker, func, args=args, timeout=0.001) + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rref._get_future().wait() wait_until_pending_futures_and_users_flushed() + with self.assertRaisesRegex(RuntimeError, "MSG"): + rref.to_here() if dst_rank != self.rank: + slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2) with self.assertRaisesRegex(RuntimeError, expected_error): + slow_rref.to_here(0.001) + if dst_rank != self.rank: + wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank) @dist_init(faulty_messages=[], messages_to_delay={"MSG": 2}) + def test_udf_remote_message_delay_timeout(self): + func = my_sleep_func + args = (2,) + self._test_remote_message_delay_timeout(func, args) @dist_init(faulty_messages=[], messages_to_delay={"MSG": 2}) + def test_udf_remote_message_delay_timeout_to_self(self): + func = my_sleep_func + args = (1,) + self._test_remote_message_delay_timeout(func, args, dst=0) @dist_init( + faulty_messages=[], + messages_to_delay={"MSG": 2, "MSG": 1}, + ) + def test_remote_message_builtin_delay_timeout(self): + func = torch.add + args = (torch.tensor(1), torch.tensor(1)) + self._test_remote_message_delay_timeout(func, args) @dist_init( + faulty_messages=[], + messages_to_delay={"MSG": 2, "MSG": 1}, + ) + def test_remote_message_builtin_delay_timeout_to_self(self): + func = torch.add + args = (torch.tensor(1), torch.tensor(1)) + self._test_remote_message_delay_timeout(func, args, dst=0) @dist_init( + faulty_messages=[], + messages_to_delay={"MSG": 2, "MSG": 1}, + ) + def test_remote_message_script_delay_timeout(self): + func = my_script_func + args = (torch.tensor(1),) + self._test_remote_message_delay_timeout(func, args) @dist_init( + faulty_messages=[], + messages_to_delay={"MSG": 2, "MSG": 1}, + ) + def test_remote_message_script_delay_timeout_to_self(self): + func = my_script_func + args = (torch.tensor(1),) + self._test_remote_message_delay_timeout(func, args, dst=0) @dist_init(faulty_messages=[], messages_to_delay={"MSG": 1}) + def test_rref_to_here_timeout(self): + if self.rank != 0: + return dst_rank = (self.rank + 1) % self.world_size + dst_worker = "MSG".format(dst_rank) + rref = rpc.remote( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rref.to_here(0.01) rref.to_here() @dist_init(faulty_messages=[]) + def test_rpc_builtin_timeout(self): + next_rank = (self.rank + 1) % self.world_size + dst_worker = worker_name(next_rank) + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rpc.rpc_sync( + dst_worker, + torch.add, + args=(torch.tensor(1), torch.tensor(1)), + timeout=1, + ) fut = rpc.rpc_async( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1 + ) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() self.assertEqual(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC, rpc.get_rpc_timeout()) + fut = rpc.rpc_async( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + fut.wait() rpc._set_rpc_timeout(0.001) + fut = rpc.rpc_async( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() fut = rpc.rpc_async( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0 + ) + fut.wait() + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) @dist_init(faulty_messages=[], messages_to_delay={"MSG": 1.5}) + def test_rpc_script_timeout(self): + next_rank = (self.rank + 1) % self.world_size + dst_worker = worker_name(next_rank) + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1) fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() self.assertEqual(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC, rpc.get_rpc_timeout()) + fut = rpc.rpc_async( + dst_worker, my_script_func, args=(torch.tensor(1),) + ) + fut.wait() rpc._set_rpc_timeout(0.001) + fut = rpc.rpc_async( + dst_worker, my_script_func, args=(torch.tensor(1),) + ) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() rpc._set_rpc_timeout(0.001) + fut = rpc.rpc_async( + dst_worker, my_script_func, args=(torch.tensor(1),), timeout=0 + ) + fut.wait() + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) class TensorPipeAgentRpcTest(RpcAgentTestFixture): def test_mismatched_type_for_options(self): + rpc_backend_options = FooBackendOptions(self.init_method) with self.assertRaisesRegex( + TypeError, "MSG" + ): + rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=rpc_backend_options, + ) def test_infer_backend_from_options(self): + rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + init_method=self.init_method + ) rpc.init_rpc( + name=worker_name(self.rank), + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) self.assertIsInstance(rpc.api._get_current_rpc_agent(), rpc.TensorPipeAgent) + @dist_init(setup_rpc=False) + def test_set_and_get_num_worker_threads(self): + NUM_THREADS = 27 + rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_worker_threads=NUM_THREADS + ) + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) info = rpc.api._get_current_rpc_agent().get_debug_info() + self.assertEqual(int(info["MSG"]), NUM_THREADS) + rpc.shutdown() + @dist_init(setup_rpc=False) + def test_tensorpipe_set_default_timeout(self): + timeout = 0.5 + rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_worker_threads=self.rpc_backend_options.num_worker_threads, + rpc_timeout=timeout + ) + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) default_timeout = rpc.get_rpc_timeout() + self.assertEqual(default_timeout, timeout) + rpc.shutdown() + @dist_init(setup_rpc=False) + def test_tensorpipe_options_throw_on_timedelta_timeout(self): + from datetime import timedelta timeout = timedelta() + with self.assertRaisesRegex(TypeError, "MSG"): + rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + init_method=self.rpc_backend_options.init_method, + num_worker_threads=self.rpc_backend_options.num_worker_threads, + rpc_timeout=timeout, + ) +import torch.distributed.rpc as rpc +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) +class TensorPipeRpcAgentTestFixture(RpcAgentTestFixture): + @property + def rpc_backend(self): + return rpc.backend_registry.BackendType[ + "MSG" + ] @property + def rpc_backend_options(self): + return rpc.backend_registry.construct_rpc_backend_options( + self.rpc_backend, + init_method=self.init_method, + ) def get_shutdown_error_regex(self): + error_regexes = ["MSG"] + return "MSG".join(["MSG".format(error_str) for error_str in error_regexes]) def get_timeout_error_regex(self): + return "MSG" +from typing import Tuple, Dict import torch +import torch.distributed.autograd as dist_autograd +import torch.distributed.rpc as rpc +from torch import Tensor +from torch.distributed.rpc import rpc_async +from torch.testing import FileCheck +from torch.testing._internal.dist_utils import dist_init, worker_name +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) +@torch.jit.script +def local_add(t1, t2): + return torch.add(t1, t2) +@torch.jit.script +def remote_add(t1, t2, dst: str): + return rpc_async(dst, local_add, (t1, t2)).wait() +@torch.jit.script +def fork_add(t1, t2, dst: str): + fut = torch.jit._fork(remote_add, t1, t2, dst) + return torch.jit._wait(fut) +class JitDistAutogradTest(RpcAgentTestFixture): + @dist_init + def test_get_gradients(self): + dst_rank = self.rank @torch.jit.script + def dist_get_gradients(context_id): + return dist_autograd.get_gradients(context_id) FileCheck().check("MSG").run(str(dist_get_gradients.graph)) + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + t3 = torch.add(t1, t2) dist_autograd.backward(context_id, [t3.sum()]) + grads = dist_get_gradients(context_id) self.assertEqual(2, len(grads)) + self.assertIn(t1, grads) + self.assertIn(t2, grads) + self.assertEqual(torch.ones(3, 3), grads[t1]) + self.assertEqual(torch.ones(3, 3), grads[t2]) @dist_init + def test_dist_backward(self): + if self.rank != 0: + return @torch.jit.script + def dist_backward_script(context_id: int, loss: torch.Tensor): + dist_autograd.backward(context_id, [loss]) FileCheck().check("MSG").run(str(dist_backward_script.graph)) + with dist_autograd.context() as context_id: + t1 = torch.rand(3, 3, requires_grad=True) + t2 = torch.rand(3, 3, requires_grad=True) + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + loss = rpc.rpc_sync(dst_worker_name, torch.add, args=(t1, t2)).sum() + dist_backward_script(context_id, loss) @dist_init + def test_jit_fork_within_context(self): + with dist_autograd.context() as context_id: + t1 = torch.rand((3, 3), requires_grad=True) + t2 = torch.rand((3, 3), requires_grad=True) + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + res = fork_add(t1, t2, dst_worker_name) + loss = res.sum() + dist_autograd.backward(context_id, [loss]) grads = dist_autograd.get_gradients(context_id) + self.assertEqual(2, len(grads)) + self.assertIn(t1, grads) + self.assertIn(t2, grads) @dist_init + def test_restore_context_after_swtich_to_jit_thread(self): + if self.rank != 0: + return @torch.jit.script + def forward_script( + context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor + ) -> Tuple[Tensor, Tensor]: + res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1)) + res1 = res1_fut.wait() + loss1 = res1.sum() res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2)) + res2 = res2_fut.wait() + loss2 = res2.sum() return loss1, loss2 with dist_autograd.context() as context_id: + t1 = torch.ones((2, 3), requires_grad=True) + t2 = torch.ones((2, 3), requires_grad=True) + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + loss0, loss1 = forward_script(context_id, dst_worker_name, t1, t2) + dist_autograd.backward(context_id, [loss0, loss1]) + grad0, grad1 = dist_autograd.get_gradients(context_id) + self.assertEqual(grad0, grad1) +import time +import io +from typing import Dict, List, Tuple, Any import torch +import torch.distributed as dist +import torch.distributed.rpc as rpc +from torch import Tensor +from torch.autograd.profiler import record_function +from torch.distributed.rpc import RRef +from torch.distributed.rpc.internal import RPCExecMode, _build_rpc_profiling_key +from torch.futures import Future +from torch.testing._internal.common_utils import TemporaryFileName +from torch.testing._internal.dist_utils import ( + dist_init, + get_function_event, + initialize_pg, + worker_name, +) +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) +def sleep(t): + time.sleep(t) +def rpc_return_rref(dst): + return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)) +@torch.jit.script +def rref_local_value(rref: RRef[Tensor]) -> Tensor: + return rref.local_value() +@torch.jit.script +def list_create() -> List[int]: + global_list = [1, 2, 3] + return global_list +@torch.jit.script +def rref_list_mutate(rref: RRef[List[int]]) -> None: + rref.local_value().append(4) + rref.to_here().append(5) + rref.to_here(5.0).append(6) +def return_value(value: int) -> int: + return value +class RRefAPITest: + @dist_init + def test_rref_is_owner(self): + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + rref_var = rpc_return_rref(dst_worker_name) @torch.jit.script + def rref_tensor_is_owner(rref_var: RRef[Tensor]) -> bool: + return rref_var.is_owner() res = rref_tensor_is_owner(rref_var) + self.assertEqual(res, False) @dist_init + def test_rref_local_value(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) + rref = rpc_return_rref(dst_worker_name) with self.assertRaisesRegex( + RuntimeError, r"MSG" + ): + rref_local_value(rref) ret = ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,)) + self.assertEqual(ret, torch.add(torch.ones(2, 2), 1)) @dist_init + def test_local_rref_local_value(self): + if self.rank != 0: + return dst_worker_name = worker_name(self.rank) + rref = rpc.remote(dst_worker_name, return_value, (5,), {}) ret = rref_local_value(rref) + self.assertEqual(ret, 5) def _create_rref(self): + owner_rank = (self.rank + 2) % self.world_size + return rpc.remote( + worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1) + ) @dist_init + def test_user_rrefs_confirmed(self): + dst_rank = (self.rank + 1) % self.world_size + rref = self._create_rref() + ret = rpc.rpc_sync( + worker_name(dst_rank), script_check_rref_confirmed, args=(rref,) + ) + self.assertEqual(ret, True) @dist_init + def test_user_rrefs_confirmed_remote(self): + dst_rank = (self.rank + 1) % self.world_size + rref = self._create_rref() + ret_rref = rpc.remote( + worker_name(dst_rank), script_check_rref_confirmed, args=(rref,) + ) + self.assertEqual(ret_rref.to_here(), True) @dist_init + def test_rref_list_mutate(self): + dst = worker_name((self.rank + 1) % self.world_size) + list_rref = rpc.remote(dst, list_create) rpc.rpc_sync(dst, rref_list_mutate, args=(list_rref,)) + self.assertEqual(list_rref.to_here(), [1, 2, 3, 4, 5, 6]) +@torch.jit.script +def no_arg(): + return 0 +@torch.jit.script +def one_arg(value): + return value + 1 @torch.jit.script +def script_add_ones(x): + return torch.add(x, torch.ones(1)) @torch.jit.script +def script_add_ones_with_record_function(x, block: str): + with record_function(block): + return torch.add(x, torch.ones(1)) +@torch.jit.script +def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor: + t: Tensor = torch.ones(1) + with record_function(block) as rf: + fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, )) + fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, )) + res = fut1.wait() + fut2.wait() + return res @torch.jit.script +def script_fork_wait_udf(tensor): + fut = torch.jit._fork(script_add_ones, tensor) + x = torch.jit._wait(fut) + return x +@torch.jit.script +def rref_to_here(rref_var: RRef[Tensor]) -> Tensor: + return rref_var.to_here() +@torch.jit.script +def return_rref(rref_var: RRef[Tensor]) -> RRef[Tensor]: + return rref_var +@torch.jit.script +def script_raise_func(value): + if value.numel() == 2: + raise ValueError("MSG") + return value + 1 +@torch.jit.script +def script_fork_wait_throw(invalue): + fut = torch.jit._fork(script_raise_func, invalue) + value = torch.jit._wait(fut) + return value +@torch.jit.script +def call_rpc_with_profiling(handle: Tensor, dst_worker_name: str) -> Tensor: + + + + fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),)) + torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut) + ret = fut.wait() + return ret @torch.jit.script +def call_rpc_torchscript_with_record_function(dst_worker_name: str, block: str) -> Tensor: + fut = rpc.rpc_async(dst_worker_name, script_add_ones_with_record_function, (torch.tensor(1), block)) + return fut.wait() +@torch.jit.script +def call_fork_with_profiling(handle: Tensor) -> Tensor: + + + + fut = torch.jit._fork(one_arg, torch.tensor(1)) + torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut) + ret = fut.wait() + return ret +class MyScriptModuleWithRRefs(torch.jit.ScriptModule): + def __init__(self, dst_worker): + super().__init__() + self.rrefs = [] + for _ in range(4): + self.rrefs.append(rpc_return_rref(dst_worker)) @torch.jit.script_method + def forward(self) -> Tensor: + res_tensor = torch.ones(2, 2) + for rref in self.rrefs: + res_tensor += rref.to_here() return res_tensor +@torch.jit.ignore +def rref_python_annotation(rref_var: RRef[Tensor]) -> RRef[Tensor]: + return rref_var +@torch.jit.script +def rref_script_annotation(rref_var: RRef[Tensor]) -> Tensor: + return rref_python_annotation(rref_var).to_here() +class RRefTypingTest: + @dist_init + def test_rref_as_arg_and_return(self): + n = self.rank + 1 + dst_rank = n % self.world_size + local_ret = one_arg(torch.ones(2, 2)) rref = rpc.remote(worker_name(self.rank), one_arg, args=(torch.ones(2, 2),)) ret = rpc.rpc_sync(worker_name(dst_rank), rref_to_here, args=(rref,)) + self.assertEqual(ret, local_ret) rref1 = rpc.rpc_sync(worker_name(dst_rank), return_rref, args=(rref,)) + self.assertEqual(rref1.to_here(), local_ret) rref2 = rpc.remote(worker_name(dst_rank), rref_to_here, args=(rref,)) + self.assertEqual(rref2.to_here(), local_ret) rref3 = rpc.remote(worker_name(dst_rank), return_rref, args=(rref,)) + self.assertEqual(rref3.to_here().to_here(), local_ret) @dist_init + def test_my_script_module_with_rrefs(self): + n = self.rank + 1 + dst_rank = n % self.world_size module_with_rrefs = MyScriptModuleWithRRefs(worker_name(dst_rank)) + res = module_with_rrefs() + self.assertEqual(res, torch.ones(2, 2) * 9) @dist_init + def test_rref_python_annotation(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref_var = rpc_return_rref(worker_name(dst_rank)) res = rref_script_annotation(rref_var) + self.assertEqual(res, torch.ones(2, 2) + 1) +class FutureTypingTest: + @dist_init + def test_future_passed_between_python_and_jit(self): + dst_rank = (self.rank + 1) % self.world_size + inputs = (torch.tensor([1, 1]), torch.tensor([2, 2])) + ret_fut = rpc.rpc_async(worker_name(dst_rank), two_args_two_kwargs, args=inputs) + expected_res = torch.tensor([10, 10]) @torch.jit.script + def future_wait_in_script(fut: Future[Tensor]) -> Tensor: + return fut.wait() self.assertEqual(future_wait_in_script(ret_fut), expected_res) @torch.jit.script + def future_return_to_python( + dst_rank: int, inputs: Tuple[Tensor, Tensor] + ) -> Future[Tensor]: + return rpc.rpc_async( + "MSG".format(dst_rank), two_args_two_kwargs, inputs + ) fut_res = future_return_to_python(dst_rank, inputs) + self.assertEqual(fut_res.wait(), expected_res) @dist_init + def test_future_python_annotation(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) + input_0 = torch.ones(2, 2) + input_1 = 1 + expected_res = torch.add(input_0, input_1) @torch.jit.ignore + def python_return_future() -> Future[Tensor]: + fut = rpc.rpc_async(dst_worker_name, torch.add, (input_0, input_1), {}) + return fut @torch.jit.script + def script_use_future() -> Tensor: + fut = python_return_future() + return fut.wait() res = script_use_future() + self.assertEqual(res, expected_res) +@torch.jit.script +class MyScriptClass: + def __init__(self, a: int): + self.a = a def get_value(self) -> int: + return self.a +@torch.jit.interface +class MyModuleInterface(torch.nn.Module): + def forward(self) -> Tensor: + pass +class MyScriptModule(torch.jit.ScriptModule): + def __init__(self, rank): + super().__init__() + self.a = torch.ones(rank) @torch.jit.script_method + def forward(self) -> Tensor: + return self.a +def owner_create_rref_my_script_class(a): + return rpc.RRef(MyScriptClass(a)) +def owner_create_rref_my_script_module(a): + return rpc.RRef(MyScriptModule(a), MyModuleInterface) +@torch.jit.script +def script_rref_get_value_my_script_class(rref: RRef[MyScriptClass]) -> int: + return rref.to_here().get_value() +@torch.jit.script +def script_rref_run_forward_my_script_module(rref: RRef[MyModuleInterface]) -> Tensor: + return rref.to_here().forward() +class LocalRRefTest: + @dist_init + def test_create_local_script_class_rref_in_py(self): + if self.rank != 0: + return rref_script_class = rpc.RRef(MyScriptClass(self.rank)) + ret = rref_script_class.to_here().get_value() + self.assertEqual(ret, self.rank) @dist_init + def test_create_local_script_module_rref_in_py(self): + if self.rank != 0: + return rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface) + ret = rref_script_module.to_here().forward() + self.assertEqual(ret, torch.ones(self.rank)) with self.assertRaisesRegex( + RuntimeError, + ( + "MSG" + "MSG" + ), + ): + rref_script_module = rpc.RRef(MyScriptModule(self.rank)) @dist_init + def test_return_local_script_class_rref_in_py_and_use_in_script(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) rref = rpc.rpc_sync( + dst_worker_name, owner_create_rref_my_script_class, args=(self.rank,) + ) def use_rref_on_owner(rref: RRef[MyScriptClass]) -> int: + args = (rref,) + kwargs: Dict[str, Any] = {} + fut = rpc.rpc_async( + rref.owner(), script_rref_get_value_my_script_class, args, kwargs + ) + ret = fut.wait() + return ret ret = use_rref_on_owner(rref) + self.assertEqual(ret, self.rank) use_rref_on_owner_script = torch.jit.script(use_rref_on_owner) + ret = use_rref_on_owner_script(rref) + self.assertEqual(ret, self.rank) @dist_init + def test_return_local_script_module_rref_in_py_and_use_in_script(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) rref = rpc.rpc_sync( + dst_worker_name, owner_create_rref_my_script_module, args=(self.rank,) + ) def use_rref_on_owner(rref: RRef[MyModuleInterface]) -> Tensor: + args = (rref,) + kwargs: Dict[str, Any] = {} + fut = rpc.rpc_async( + rref.owner_name(), + script_rref_run_forward_my_script_module, + args, + kwargs, + ) + ret = fut.wait() + return ret ret = use_rref_on_owner(rref) + self.assertEqual(ret, torch.ones(self.rank)) use_rref_on_owner_script = torch.jit.script(use_rref_on_owner) + ret = use_rref_on_owner_script(rref) + self.assertEqual(ret, torch.ones(self.rank)) +def python_function(): + return 0 +@torch.jit.script +def two_args_two_kwargs( + first_arg, + second_arg, + first_kwarg=torch.tensor([3, 3]), + second_kwarg=torch.tensor([4, 4]), +): + return first_arg + second_arg + first_kwarg + second_kwarg +@torch.jit.script +def assorted_types_args_kwargs( + tensor_arg: Tensor, + str_arg: str, + int_arg: int, + tensor_kwarg: Tensor = torch.tensor([2, 2]), + str_kwarg: str = "MSG", + int_kwarg: int = 2, +): + return tensor_arg + tensor_kwarg, str_arg + str_kwarg, int_arg + int_kwarg +@torch.jit.script +def raise_script(): + raise RuntimeError("MSG") +@torch.jit.script +def script_rpc_async_call( + dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor] +): + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + ret = fut.wait() + return ret @torch.jit.script +def script_rpc_sync_call( + dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor] +): + res = rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs) + return res @torch.jit.script +def script_rpc_remote_call( + dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor] +): + rref_res = rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs) + return rref_res.to_here() class JitRpcOpTest: + + @dist_init + def test_all_kwargs_are_populated_by_defaults(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = {} for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]: + ret = script_op( + dst_worker_name, args, kwargs + ) + self.assertEqual(ret, torch.tensor([10, 10])) @dist_init + def test_some_kwargs_are_populated_by_defaults(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = {"MSG": torch.tensor([2, 2])} for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]: + ret = script_op( + dst_worker_name, args, kwargs + ) + self.assertEqual(ret, torch.tensor([9, 9])) @dist_init + def test_no_kwargs_are_populated_by_defaults(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = { + "MSG": torch.tensor([2, 2]), + "MSG": torch.tensor([3, 3]), + } + for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]: + ret = script_op( + dst_worker_name, args, kwargs + ) + self.assertEqual(ret, torch.tensor([8, 8])) @dist_init + def test_args_and_kwargs_contain_different_types(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) @torch.jit.script + def script_rpc_async_call_with_assorted_types( + dst_worker_name: str, + ): + args = (torch.tensor([1, 1]), "MSG", 1) + kwargs: Dict[str, Any] = { + "MSG": torch.tensor([3, 3]), + "MSG": "MSG", + "MSG": 3, + } + fut = rpc.rpc_async( + dst_worker_name, assorted_types_args_kwargs, args, kwargs + ) + ret = fut.wait() + return ret ret = script_rpc_async_call_with_assorted_types( + dst_worker_name + ) + self.assertEqual(ret, (torch.tensor([4, 4]), "MSG", 4)) @dist_init + def test_kwargs_not_passed(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) @torch.jit.script + def script_rpc_async_call_without_kwargs_passed( + dst_worker_name: str, + ): + args = () + fut = rpc.rpc_async(dst_worker_name, no_arg, args) + ret = fut.wait() + return ret ret = script_rpc_async_call_without_kwargs_passed( + dst_worker_name + ) + self.assertEqual(ret, 0) @dist_init + def test_args_kwargs_are_neither_passed(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) @torch.jit.script + def script_rpc_async_call_without_args_kwargs_passed( + dst_worker_name: str, + ): + fut = rpc.rpc_async(dst_worker_name, no_arg) + ret = fut.wait() + return ret ret = script_rpc_async_call_without_args_kwargs_passed( + dst_worker_name + ) + self.assertEqual(ret, 0) @dist_init + def test_less_than_needed_args_are_specified(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) with self.assertRaisesRegex(RuntimeError, "MSG"): @torch.jit.script + def script_rpc_async_call_with_less_args( + dst_worker_name: str, + ): + args = (torch.tensor([1, 1]),) + kwargs = {} + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + ret = fut.wait() + return ret @dist_init + def test_more_than_needed_args_are_specified(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) with self.assertRaisesRegex( + RuntimeError, + "MSG", + ): @torch.jit.script + def script_rpc_async_call_with_more_args( + dst_worker_name: str, + ): + args = ( + torch.tensor([1, 1]), + torch.tensor([2, 2]), + torch.tensor([3, 3]), + torch.tensor([4, 4]), + torch.tensor([5, 5]), + ) + kwargs = {} + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + ret = fut.wait() + return ret @dist_init + def test_unexepected_kwarg_is_specified(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) @torch.jit.script + def script_rpc_async_call_with_unexpected_kwarg( + dst_worker_name: str, + ): + args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = {"MSG": torch.tensor([1, 1])} + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + ret = fut.wait() + return ret with self.assertRaisesRegex( + RuntimeError, "MSG" + ): + ret = script_rpc_async_call_with_unexpected_kwarg( + dst_worker_name + ) + self.assertEqual(ret, 0) @dist_init + def test_call_python_function_remotely_from_script_not_supported(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) @torch.jit.script + def rpc_async_call_remote_py_function_in_torchscript(dst_worker_name: str): + args = () + kwargs = {} + fut = rpc.rpc_async(dst_worker_name, python_function, args, kwargs) + ret = fut.wait() + return ret with self.assertRaisesRegex( + RuntimeError, "MSG" + ): + ret = rpc_async_call_remote_py_function_in_torchscript(dst_worker_name) + self.assertEqual(ret, 0) @dist_init + def test_call_script_function_that_raises_remotely_from_script(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) @torch.jit.script + def rpc_async_call_remote_raising_torchscript_in_torchscript( + dst_worker_name: str, + ): + args = () + kwargs = {} + fut = rpc.rpc_async(dst_worker_name, raise_script, args, kwargs) + ret = fut.wait() + return ret with self.assertRaisesRegex(RuntimeError, "MSG"): + ret = rpc_async_call_remote_raising_torchscript_in_torchscript( + dst_worker_name + ) + self.assertEqual(ret, 0) @dist_init + def test_call_script_function_that_not_exists_remotely_from_script(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) @torch.jit.script + def nonexisting_script(): + return 0 @torch.jit.script + def rpc_async_call_remote_nonexisting_torchscript_in_torchscript( + dst_worker_name: str, + ): + args = () + kwargs = {} + fut = rpc.rpc_async(dst_worker_name, nonexisting_script, args, kwargs) + ret = fut.wait() + return ret with self.assertRaisesRegex( + RuntimeError, "MSG" + ): + ret = rpc_async_call_remote_nonexisting_torchscript_in_torchscript( + dst_worker_name + ) + self.assertEqual(ret, 0) +@torch.jit.ignore +def my_script_module_init(rank: int) -> MyModuleInterface: + return MyScriptModule(rank) +@torch.jit.script +def construct_my_script_module(rank: int) -> MyModuleInterface: + return my_script_module_init(rank) +@torch.jit.script +def run_ref_script_module( + ref_script_module: RRef[MyModuleInterface], t: Tensor +) -> Tensor: + module = ref_script_module.to_here() + return module.forward() + t +@torch.jit.script +def script_check_rref_confirmed(rref: RRef[Tensor]) -> bool: + return rref.confirmed_by_owner() +@torch.jit.script +def save_rref(rref_var: RRef[Tensor], fname: str) -> None: + torch.save(rref_var, fname) +@torch.jit.script +def script_add(x: Tensor, y: Tensor) -> Tensor: + return x + y +@rpc.functions.async_execution +@torch.jit.script +def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]: + return rpc.rpc_async(to, script_add, (x, y)) +@rpc.functions.async_execution +@torch.jit.script +def async_wrong_type() -> Tensor: + return torch.zeros(2) +def load_script_module_with_pickled_rref(pickled_script_module): + f = io.BytesIO(pickled_script_module) + m = torch.jit.load(f) + return m() +class JitRpcTest( + RRefAPITest, + RRefTypingTest, + LocalRRefTest, + JitRpcOpTest, + FutureTypingTest, + RpcAgentTestFixture, +): + @dist_init + def test_torchscript_function(self): + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + local_ret = one_arg(torch.ones(2, 2)) + ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(torch.ones(2, 2),)) + self.assertEqual(ret, local_ret) + rref = rpc.remote(dst_worker_name, one_arg, args=(torch.ones(2, 2),)) + self.assertEqual(rref.to_here(), local_ret) + local_rref = rpc.remote( + worker_name(self.rank), one_arg, args=(torch.ones(2, 2),) + ) + self.assertEqual(local_rref.to_here(), local_ret) @dist_init + def test_torchscript_function_exception(self): + dst_worker_name = worker_name((self.rank + 1) % self.world_size) + with self.assertRaisesRegex(RuntimeError, r"MSG"): + ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(10, 20)) with self.assertRaisesRegex(RuntimeError, r"MSG"): + rref = rpc.remote(dst_worker_name, one_arg, args=(10, 20)) @dist_init + def test_torchscript_functions_not_supported(self): + dst_worker_name = worker_name((self.rank + 1) % self.world_size) my_local_script_module = MyScriptModule(self.rank) initialize_pg(self.init_method, self.rank, self.world_size) + dist.barrier() ret = rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,)) with self.assertRaisesRegex(RuntimeError, "MSG"): + ret = rpc.rpc_sync(dst_worker_name, MyScriptModule, args=(self.rank,)) with self.assertRaisesRegex(TypeError, "MSG"): + ret = rpc.rpc_async( + dst_worker_name, my_local_script_module.forward, args=() + ) @dist_init + def test_remote_script_module(self): + import torch.distributed.rpc.api as api api._ignore_rref_leak = True local_ret = torch.ones(self.rank) + torch.ones(self.rank) n = self.rank + 1 + dst_rank = n % self.world_size + remote_ref = rpc.remote( + worker_name(dst_rank), construct_my_script_module, args=(self.rank,) + ) ret = rpc.rpc_sync( + worker_name(dst_rank), + run_ref_script_module, + args=(remote_ref, torch.ones(self.rank)), + ) + self.assertEqual(ret, local_ret) with self.assertRaisesRegex( + RuntimeError, + "MSG", + ): + ret = rpc.rpc_sync( + worker_name(self.rank), + run_ref_script_module, + args=(remote_ref, torch.ones(self.rank)), + ) @dist_init + def test_load_script_module_with_pickled_rref(self): + dst_name = worker_name((self.rank + 1) % self.world_size) + m1 = MyScriptModuleWithRRefs(dst_name) + m2 = MyScriptModuleWithRRefs(dst_name) f = io.BytesIO() rpc._enable_jit_rref_pickle() + torch.jit.save(m1, f) + rpc._disable_jit_rref_pickle() out1 = rpc.rpc_sync( + dst_name, + load_script_module_with_pickled_rref, + args=(f.getvalue(),) + ) + out2 = m2() + self.assertEqual(out1, out2) @dist_init + def test_rref_jit_pickle_not_supported(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref_var = rpc_return_rref(worker_name(dst_rank)) + with TemporaryFileName() as fname: + with self.assertRaisesRegex( + RuntimeError, "MSG" + ): + save_rref(rref_var, fname) @dist_init + def test_remote_script_throw(self): + rref = rpc.remote( + worker_name((self.rank + 1) % self.world_size), + script_raise_func, + args=(torch.ones(2),), + ) + with self.assertRaisesRegex(Exception, "MSG"): + rref.to_here() @dist_init + def test_remote_script_udf(self): + rref = rpc.remote( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_udf, + args=(torch.ones(2),), + ) + self.assertEqual(rref.to_here(), torch.ones(2) * 2) @dist_init + def test_async_script_udf(self): + future = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_udf, + args=(torch.ones(2),), + ) + self.assertEqual(future.wait(), torch.ones(2) * 2) @dist_init + def test_callback_simple(self): + def callback(fut): + return fut.wait() + 1 future = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_udf, + args=(torch.ones(2),), + ).then(callback) + self.assertEqual(future.wait(), torch.ones(2) * 2 + 1) @dist_init + def test_callback_chain(self): + n = self.rank + 1 + dst = worker_name(n % self.world_size) def callback(fut): + return fut.wait() + 1 fut = rpc.rpc_async( + worker_name(n % self.world_size), one_arg, args=(torch.ones(n, n),) + ) num_cbs = 20 + for _ in range(num_cbs): + fut = fut.then(callback) self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs) @dist_init + def test_async_script_throw(self): + future = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_throw, + args=(torch.ones(2),), + ) + with self.assertRaisesRegex(Exception, "MSG"): + future.wait() @dist_init + def test_callback_with_exception(self): + def callback(fut): + with self.assertRaisesRegex(Exception, "MSG"): + fut.wait() + raise RuntimeError("MSG") future = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_throw, + args=(torch.ones(2),), + ).then(callback) with self.assertRaisesRegex(RuntimeError, "MSG"): + future.wait() @dist_init + def test_call_rpc_with_profiling(self): + if self.rank == 0: + with torch.autograd.profiler.profile() as prof: + prof_key = _build_rpc_profiling_key( + RPCExecMode.ASYNC, + torch._jit_internal._qualified_name(one_arg), + "MSG", + "MSG", + ) + with torch.autograd.profiler.record_function(prof_key) as rf: + ret = call_rpc_with_profiling(rf.handle, "MSG") + events = prof.function_events + function_event = get_function_event(events, prof_key) + self.assertTrue(torch._jit_internal._qualified_name(one_arg) in function_event.name) @dist_init + def test_rpc_async_jit_profiled(self): + if self.rank == 0: + dst_rank = (self.rank + 1) % self.world_size + dst_worker_name = worker_name(dst_rank) + args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = {} + with torch.autograd.profiler.profile() as prof: + script_rpc_async_call( + dst_worker_name, args, kwargs + ) function_events = prof.function_events + qual_name = torch._jit_internal._qualified_name(two_args_two_kwargs) + rpc_async_jit_event = [ + event + for event in function_events + if qual_name in event.name and event.node_id == self.rank + ] + self.assertEqual(len(rpc_async_jit_event), 1) + rpc_async_jit_event = rpc_async_jit_event[0] + profiled_name = _build_rpc_profiling_key( + RPCExecMode.ASYNC_JIT, + qual_name, + worker_name(self.rank), + dst_worker_name, + ) + self.assertEqual(profiled_name, rpc_async_jit_event.name) + remote_events = [event for event in function_events if event.is_remote] + remote_event_node_ids = { + remote_event.node_id for remote_event in remote_events + } + self.assertEqual(remote_event_node_ids, {dst_rank}) + remote_add = [ + remote_event + for remote_event in remote_events + if "MSG" in remote_event.name + ][0] + remote_add_profiled_name = f"MSG"foo"MSG" + if self.rank == 0: + dst_rank = (self.rank + 1) % self.world_size + dst_worker_name = worker_name(dst_rank) + block_scope = "MSG" + with torch.autograd.profiler.profile() as prof: + call_rpc_torchscript_with_record_function(dst_worker_name, block_scope) prof.key_averages() + function_events = prof.function_events + expected_key = ( + _build_rpc_profiling_key( + RPCExecMode.ASYNC_JIT, + torch._jit_internal._qualified_name( + script_add_ones_with_record_function + ), + worker_name(self.rank), + dst_worker_name, + ) + + REMOTE_OP_STR + + block_scope + ) + remote_record_function_event = [ + evt for evt in function_events if evt.name == expected_key + ][0] + self.assertTrue(block_scope in remote_record_function_event.name) + remote_children = remote_record_function_event.cpu_children + self.assertTrue("MSG" in child.name for child in remote_children) def test_record_function_jit_end_callbacks_with_fork(self): + sleep_interval = 1 + with torch.autograd.profiler.profile() as prof: + with torch.autograd.profiler.record_function("MSG") as rf: + fut = torch.jit._fork(sleep, sleep_interval) + rf._call_end_callbacks_on_future(fut) + fut.wait() function_events = prof.function_events + sleep_event = get_function_event(function_events, "MSG") + self.assertEqual(sleep_event.name, "MSG") + self.assertGreaterEqual(sleep_event.cpu_time * 1e-6, sleep_interval) def test_call_fork_in_jit_with_profiling(self): + with torch.autograd.profiler.profile() as prof: + with torch.autograd.profiler.record_function("MSG") as rf: + ret = call_fork_with_profiling(rf.handle) events = prof.function_events + function_event = get_function_event(events, "MSG") + self.assertEqual(function_event.name, "MSG") @dist_init + def test_async_function_simple(self): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) ret = rpc.rpc_sync( + dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2)) + ) + self.assertEqual(ret, torch.ones(2, 2) + 1) @dist_init + def test_async_function_wrong_return_type(self): + with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc.rpc_sync( + worker_name((self.rank + 1) % self.world_size), async_wrong_type + ) @dist_init + def test_async_function_wrong_decorator_order(self): + with self.assertRaises(RuntimeError): @torch.jit.script + @rpc.functions.async_execution + def async_wrong_decorator_order( + to: str, x: Tensor, y: Tensor + ) -> Future[Tensor]: + return rpc.rpc_async(to, script_add, (x, y)) @dist_init + def test_async_function_remote(self): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) rref = rpc.remote( + dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2)) + ) + self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1) @dist_init + def test_async_function_remote_multi(self): + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) num = 20 + rrefs = [] + for i in range(num): + rrefs.append( + rpc.remote( + dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i) + ) + ) for i in range(num): + self.assertEqual(rrefs[i].to_here(), torch.ones(2, 2) + i) @dist_init + def test_async_function_wrong_return_type_remote(self): + rref = rpc.remote( + worker_name((self.rank + 1) % self.world_size), async_wrong_type + ) with self.assertRaisesRegex(RuntimeError, "MSG"): + rref.to_here() +from typing import Dict, Tuple import torch +import torch.distributed.rpc as rpc +from torch import Tensor +from torch.testing._internal.dist_utils import ( + dist_init, + worker_name, + wait_until_pending_futures_and_users_flushed +) +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) +@torch.jit.script +def two_args_two_kwargs( + first_arg, + second_arg, + first_kwarg=torch.tensor([3, 3]), + second_kwarg=torch.tensor([4, 4]), +): + return first_arg + second_arg + first_kwarg + second_kwarg +@torch.jit.script +def script_rpc_async_call( + dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor] +): + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + ret = fut.wait() + return ret +@torch.jit.script +def rpc_async_call_with_timeout( + dst_worker_name: str, + args: Tuple[Tensor, Tensor], + kwargs: Dict[str, Tensor], + timeout: float, +): + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout) + ret = fut.wait() + return ret +@torch.jit.script +def rpc_async_call_with_timeout_future_ret( + dst_worker_name: str, + args: Tuple[Tensor, Tensor], + kwargs: Dict[str, Tensor], + timeout: float, +): + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout) + return fut +@torch.jit.script +def rpc_async_call_future_ret( + dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor] +): + fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs) + return fut @torch.jit.script +def rref_to_here(rref_var): + + return rref_var.to_here() @torch.jit.script +def rref_to_here_with_timeout(rref_var, timeout): + + return rref_var.to_here(timeout) @torch.jit.script +def rpc_async_with_rref_arg(dst_worker_name, args): + + fut = rpc.rpc_async(dst_worker_name, rref_to_here, args) + ret = fut.wait() + return ret +class JitFaultyAgentRpcTest(RpcAgentTestFixture): + + @dist_init(faulty_messages=[], messages_to_delay={"MSG": 1.5}) + def test_timeout_in_torchscript_function(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = { + "MSG": torch.tensor([2, 2]), + "MSG": torch.tensor([3, 3]), + } + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5) rpc._set_rpc_timeout(0.001) + with self.assertRaisesRegex(RuntimeError, expected_error): + script_rpc_async_call( + dst_worker_name, args, kwargs + ) ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0) + self.assertEqual(ret, torch.tensor([8, 8])) + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) @dist_init(faulty_messages=[], messages_to_delay={"MSG": 1.5}) + def test_timeout_in_python(self): + if self.rank != 0: + return dst_worker_name = worker_name((self.rank + 1) % self.world_size) + args = (torch.tensor([1, 1]), torch.tensor([2, 2])) + kwargs = { + "MSG": torch.tensor([2, 2]), + "MSG": torch.tensor([3, 3]), + } + expected_error = self.get_timeout_error_regex() fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() rpc._set_rpc_timeout(0.001) + fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs) + with self.assertRaisesRegex(RuntimeError, expected_error): + fut.wait() fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0) + result = fut.wait() + self.assertEqual(result, torch.tensor([8, 8])) + rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) @dist_init(faulty_messages=["MSG"]) + def test_remote_timeout_to_here_in_jit(self): + if self.rank != 0: + return + dst_rank = (self.rank + 1) % self.world_size + dst_worker = "MSG".format(dst_rank) + rref = rpc.remote( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + wait_until_pending_futures_and_users_flushed() + with self.assertRaisesRegex(RuntimeError, "MSG"): + rref_to_here(rref) @dist_init(faulty_messages=[], messages_to_delay={"MSG": 1}) + def test_rref_to_here_timeout_in_jit(self): + if self.rank != 0: + return dst_rank = (self.rank + 1) % self.world_size + dst_worker = "MSG".format(dst_rank) + rref = rpc.remote( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + rref_to_here_with_timeout(rref, 0.01) rref_to_here_with_timeout(rref, 100) @dist_init(faulty_messages=["MSG"]) + def test_rref_timeout_pickle_in_jit(self): + if self.rank != 0: + return + dst_rank = (self.rank + 1) % self.world_size + dst_worker = "MSG".format(dst_rank) + rref = rpc.remote( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + wait_until_pending_futures_and_users_flushed() + with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc_async_with_rref_arg(dst_worker, (rref, )) @dist_init(faulty_messages=["MSG"]) + def test_rref_timeout_pickle_script_func(self): + if self.rank != 0: + return + dst_rank = (self.rank + 1) % self.world_size + dst_worker = "MSG".format(dst_rank) + rref = rpc.remote( + dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)) + ) + wait_until_pending_futures_and_users_flushed() + with self.assertRaisesRegex(RuntimeError, "MSG"): + rpc.rpc_sync(dst_worker, rref_to_here, args=(rref, )) import torch annotated_args = { + torch._C._VariableFunctions._cast_Byte: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cast_Char: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cast_Double: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cast_Float: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cast_Int: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cast_Long: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cast_Short: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cast_Half: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.align_tensors: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._use_cudnn_ctc_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cudnn_ctc_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._use_cudnn_rnn_flatten_weight: [], + torch._C._VariableFunctions._cudnn_rnn_flatten_weight: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cudnn_rnn: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cudnn_init_dropout_state: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._debug_has_internal_overlap: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._fused_dropout: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._masked_scale: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sobol_engine_draw: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sobol_engine_ff_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sobol_engine_scramble_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sobol_engine_initialize_state_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._reshape_from_tensor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._shape_as_tensor: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.dropout: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.dropout_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.feature_dropout: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.feature_dropout_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.alpha_dropout: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.alpha_dropout_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.feature_alpha_dropout: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.feature_alpha_dropout_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.abs: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.abs: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.abs_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.absolute: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.absolute: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.angle: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.angle: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.view_as_real: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.view_as_complex: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sgn: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sgn: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.real: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.imag: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.conj: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.conj: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._conj: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.acos: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.acos: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.acos_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arccos: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arccos: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arccos_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.avg_pool1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.adaptive_avg_pool1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.adaptive_max_pool1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._add_relu: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._add_relu: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._add_relu_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addmv: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addmv: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addmv_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._addmv_impl_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addr: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addr: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.affine_grid_generator: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.all: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.all: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.all: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.all: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.all: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.allclose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.any: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.any: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.any: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.any: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.any: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._dim_arange: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.argmax: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.argmin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.acosh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.acosh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.acosh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arccosh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arccosh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arccosh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.asinh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.asinh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.asinh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arcsinh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arcsinh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arcsinh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atanh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arctanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arctanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arctanh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.as_strided: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.as_strided_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.asin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.asin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.asin_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arcsin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arcsin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arcsin_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atan_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arctan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arctan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.arctan_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atleast_1d: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atleast_1d: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atleast_2d: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atleast_2d: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atleast_3d: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atleast_3d: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.baddbmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.baddbmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._baddbmm_mkl_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bartlett_window: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bartlett_window: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.batch_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantized_batch_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._batch_norm_impl_index: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bernoulli: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bernoulli: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bernoulli: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bilinear: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.binary_cross_entropy_with_logits: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bincount: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_not: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_not: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logical_not: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logical_not: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logical_xor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logical_xor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logical_and: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logical_and: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logical_or: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logical_or: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.blackman_window: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.blackman_window: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._bmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._bmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.broadcast_tensors: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cat: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cat: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cat: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cat: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.block_diag: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ceil: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ceil: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ceil_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.chain_matmul: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.unsafe_chunk: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.chunk: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clamp: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clamp: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clamp_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clamp_max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clamp_max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clamp_max_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clamp_min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clamp_min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clamp_min_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clip: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clip: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.clip_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cudnn_is_acceptable: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.complex: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.complex: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.polar: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.polar: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.constant_pad_nd: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.convolution: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._convolution: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._convolution: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._convolution_nogroup: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.conv1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.conv2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.conv3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.conv_tbc: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.conv_transpose1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.conv_transpose2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.conv_transpose3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._copy_from: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cos: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cos: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cos_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cosh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cosh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cosh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cosine_embedding_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.count_nonzero: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.count_nonzero: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cudnn_affine_grid_generator: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cudnn_batch_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cudnn_convolution: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cudnn_convolution: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cudnn_convolution: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cudnn_convolution_transpose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cudnn_convolution_transpose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cudnn_convolution_transpose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cudnn_grid_sampler: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cummax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cummax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cummax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cummax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cummax_helper: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cummin: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cummin: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cummin: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cummin: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cummin_helper: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cumprod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cumprod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cumprod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cumprod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cumsum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cumsum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cumsum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cumsum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ctc_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ctc_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._ctc_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.diag_embed: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.diagflat: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.diagonal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.diagonal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.div: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.div: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.true_divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.true_divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.true_divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.dot: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.dot: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.vdot: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.vdot: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.einsum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.embedding: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.embedding_renorm_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._embedding_bag_forward_only: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.embedding_bag: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._embedding_bag: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.empty_meta: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.empty: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.empty: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.empty: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._empty_affine_quantized: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._empty_per_channel_affine_quantized: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.empty_quantized: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.empty_like: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.empty_strided: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.erf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.erf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.erf_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.erfc: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.erfc: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.erfc_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.exp: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.exp: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.exp_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.exp2: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.exp2: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.exp2_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.expm1: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.expm1: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.expm1_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.eye: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.eye: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.eye: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.eye: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.flatten: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.flatten: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.flatten: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.flatten: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fill_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fill_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.floor: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.floor: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.floor_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.floor_divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.floor_divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.floor_divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.frac: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.frac: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.frac_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.full_like: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.from_file: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gcd: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gcd: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gcd_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lcm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lcm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lcm_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.grid_sampler: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.grid_sampler_2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._grid_sampler_2d_cpu_fallback: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.grid_sampler_3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hann_window: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hann_window: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hamming_window: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hamming_window: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hamming_window: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hamming_window: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.kaiser_window: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.kaiser_window: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.kaiser_window: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hinge_embedding_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.group_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.native_group_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ifft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rfft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.irfft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._fft_with_size: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._fft_with_size: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cufft_get_plan_cache_size: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cufft_get_plan_cache_max_size: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cufft_set_plan_cache_max_size: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cufft_clear_plan_cache: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_copy: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_copy: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_put_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_put: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._index_put_impl_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.instance_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.inverse: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.inverse: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.isclose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.isnan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.is_distributed: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.is_floating_point: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.is_complex: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.isreal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.is_nonzero: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.is_same_size: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.is_signed: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.kl_div: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.kthvalue: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.kthvalue: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.kthvalue: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.kthvalue: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.layer_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.native_layer_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fbgemm_linear_int8_weight_fp32_activation: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fbgemm_linear_int8_weight: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fbgemm_linear_quantize_weight: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fbgemm_pack_gemm_matrix_fp16: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight_fp32_activation: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fbgemm_pack_quantized_matrix: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fbgemm_pack_quantized_matrix: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.linspace: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.linspace: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log10: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log10: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log10_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log1p: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log1p: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log1p_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log2: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log2: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log2_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logaddexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logaddexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logaddexp2: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logaddexp2: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logdet: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logspace: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logspace: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.log_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._log_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._log_softmax_backward_data: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._logcumsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._logcumsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logcumsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logcumsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logcumsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logcumsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.logsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.logsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.logsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.margin_ranking_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.matmul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.matmul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.matrix_rank: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.matrix_rank: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.matrix_power: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.matrix_exp: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._aminmax: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._aminmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._compute_linear_combination: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._compute_linear_combination: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.max: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.amax: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.amax: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.max_pool1d_with_indices: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.max_pool1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.max_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._VariableFunctions.mkldnn_max_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._VariableFunctions.mkldnn_max_pool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._VariableFunctions.quantized_max_pool1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.quantized_max_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._VariableFunctions.max_pool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._VariableFunctions.mean: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mean: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.mean: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.mean: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.mean: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.median: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.median: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.median: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.median: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.median: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.min: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.amin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.amin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mkldnn_convolution: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mkldnn_convolution_backward_weights: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.miopen_batch_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.miopen_convolution: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.miopen_convolution_transpose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.miopen_depthwise_convolution: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.miopen_rnn: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_mm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mode: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mode: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mode: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mode: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.multiply: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.multiply: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.multiply: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mv: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mv: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mvlgamma: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.narrow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.narrow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.native_batch_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.native_batch_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.batch_norm_stats: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.batch_norm_elemt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.batch_norm_elemt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.batch_norm_gather_stats: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.batch_norm_gather_stats_with_counts: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.batch_norm_backward_reduce: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.batch_norm_backward_elemt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.batch_norm_update_stats: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.is_vulkan_available: [], + torch._C._VariableFunctions._nnpack_available: [], + torch._C._VariableFunctions._nnpack_spatial_convolution: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._VariableFunctions.ones: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ones: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ones: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ones_like: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.pairwise_distance: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cdist: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._euclidean_dist: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.pdist: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cosine_similarity: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.movedim: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.movedim: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.pixel_shuffle: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.channel_shuffle: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.pinverse: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.poisson_nll_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rad2deg: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rad2deg: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rad2deg_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.deg2rad: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.deg2rad: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.deg2rad_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.scalar_tensor: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rand: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rand: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rand: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rand: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rand: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rand: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rand_like: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randint_like: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randint_like: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randn: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randn: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randn: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randn: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randn: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randn: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randn_like: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randperm: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randperm: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randperm: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.randperm: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.reciprocal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.reciprocal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.reciprocal_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.neg: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.neg: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.neg_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.negative: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.negative: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.negative_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.repeat_interleave: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.repeat_interleave: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.repeat_interleave: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.reshape: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._mkldnn_reshape: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.round: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.round: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.round_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rrelu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rrelu_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.relu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.relu_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.prelu: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hardshrink: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rsqrt: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rsqrt: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rsqrt_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.selu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.selu_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.celu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.celu_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sigmoid: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sigmoid: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sigmoid_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logit: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logit: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.logit_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sin_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sinh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sinh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sinh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.detach: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.detach_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.slogdet: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.smm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._softmax_backward_data: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.unsafe_split: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.split: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.unsafe_split_with_sizes: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.split_with_sizes: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.squeeze: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.squeeze: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.squeeze: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sspaddmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sspaddmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.stack: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.stack: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hstack: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hstack: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.vstack: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.vstack: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.dstack: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.dstack: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.stft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.istft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sum: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.sum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.sum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.sum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.nansum: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.nansum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.nansum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.sqrt: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sqrt: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sqrt_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.square: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.square_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.std: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.std: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.std: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.std: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.std: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.std_mean: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.std_mean: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.std_mean: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.prod: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.prod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.prod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.prod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.prod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.t: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.tan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.tan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.tan_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.tanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.tanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.tanh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.tensordot: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.threshold: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.threshold: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.threshold_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.transpose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.transpose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._mkldnn_transpose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._mkldnn_transpose_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.flip: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fliplr: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.flipud: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.roll: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.rot90: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.trapz: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.trapz: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._trilinear: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.triplet_margin_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.trunc: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.trunc: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.trunc_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fix: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fix: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fix_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._has_compatible_shallow_copy_type: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._unique: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.unique_dim: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.unique_consecutive: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._unique2: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.unsqueeze: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.vander: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.var: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.var: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.var: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.var: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.var: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.var_mean: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.var_mean: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.var_mean: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.where: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.where: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.where: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.where: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.where: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._s_where: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.norm_except_dim: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._weight_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._weight_norm_cuda_interface: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.zeros: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.zeros: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.zeros: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.zeros_like: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._standard_gamma_grad: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._standard_gamma: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._dirichlet_grad: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sample_dirichlet: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.poisson: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.binomial: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.native_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.native_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_sum: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_sum: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_sum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions._sparse_sum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions._sparse_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_softmax_backward_data: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_log_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_log_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_log_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_log_softmax_backward_data: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.frobenius_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.frobenius_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.frobenius_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._VariableFunctions.nuclear_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.nuclear_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.nuclear_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._VariableFunctions.nuclear_norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._VariableFunctions.clone: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.resize_as_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.zero_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sub: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sub: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.subtract: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.subtract: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.subtract: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rsub: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rsub: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.heaviside: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.heaviside: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._sparse_addmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._validate_sparse_coo_tensor_args: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hspmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hspmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.unbind: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.unbind: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantize_per_channel: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.dequantize: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.dequantize: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.q_scale: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.q_zero_point: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.q_per_channel_scales: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.q_per_channel_zero_points: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.q_per_channel_axis: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.int_repr: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._make_per_tensor_quantized_tensor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._make_per_channel_quantized_tensor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fake_quantize_per_tensor_affine: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._fake_quantize_learnable_per_tensor_affine: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fake_quantize_per_channel_affine: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._fake_quantize_learnable_per_channel_affine: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._choose_qparams_per_tensor: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._saturate_weight_to_fp16: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.choose_qparams_optimized: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.meshgrid: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cartesian_prod: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.combinations: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.result_type: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.result_type: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.result_type: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.result_type: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.can_cast: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.promote_types: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lstm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lstm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gru: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gru: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rnn_tanh: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rnn_tanh: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rnn_relu: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rnn_relu: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lstm_cell: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gru_cell: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rnn_tanh_cell: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.rnn_relu_cell: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantized_lstm_cell: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantized_gru_cell: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantized_rnn_relu_cell: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantized_rnn_tanh_cell: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._pack_padded_sequence: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._pad_packed_sequence: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.masked_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.masked_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.masked_scatter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.scatter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.scatter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.scatter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.scatter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.scatter_add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.scatter_add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_and: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_and: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_and: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_and: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.__and__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.__and__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_or: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_or: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_or: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_or: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.__or__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.__or__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_xor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_xor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_xor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bitwise_xor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.__xor__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.__xor__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.__lshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.__lshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.__rshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.__rshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addbmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addbmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.diag: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.diag: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cross: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cross: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.triu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.triu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.tril: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.tril: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.tril_indices: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.triu_indices: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.trace: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ne: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ne: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ne: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ne: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.not_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.not_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.not_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.not_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.eq: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.eq: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.eq: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.eq: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ge: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ge: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ge: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ge: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.greater_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.greater_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.greater_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.greater_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.le: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.le: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.le: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.le: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.less_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.less_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.less_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.less_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.greater: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.greater: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.greater: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.greater: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.less: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.less: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.less: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.less: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.take: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.take: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.index_select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.masked_select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.masked_select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gather: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gather: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gather: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.gather: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addcmul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addcmul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addcdiv: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.addcdiv: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lstsq: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lstsq: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.triangular_solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.triangular_solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.symeig: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.symeig: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.eig: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.eig: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.svd: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.svd: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cholesky: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cholesky: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cholesky_solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cholesky_solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cholesky_inverse: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.cholesky_inverse: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.qr: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.qr: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.geqrf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.geqrf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.orgqr: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.orgqr: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ormqr: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ormqr: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._lu_with_info: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lu_solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lu_solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._lu_solve_helper: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.multinomial: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.multinomial: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._multinomial_alias_setup: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._multinomial_alias_draw: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lgamma: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lgamma: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.digamma: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.digamma: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.polygamma: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.polygamma: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.erfinv: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.erfinv: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.i0: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.i0: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.i0_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sign: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sign: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.signbit: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.signbit: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.dist: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atan2: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.atan2: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lerp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lerp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lerp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.lerp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.histc: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.histc: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fmod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fmod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fmod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fmod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hypot: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.hypot: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.nextafter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.nextafter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.remainder: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.remainder: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.remainder: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.remainder: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.maximum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.maximum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.minimum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.minimum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.quantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.nanquantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.nanquantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.nanquantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.nanquantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sort: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sort: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sort: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.sort: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.argsort: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.argsort: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.topk: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.topk: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.renorm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.renorm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.pow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.pow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.pow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.pow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.pow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.pow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.normal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.normal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.normal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.normal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.normal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.normal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.normal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.normal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._index_copy_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._var: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._std: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._amp_non_finite_check_and_unscale_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._amp_update_scale: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cat: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._cat: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_add_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_add_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_sub: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_sub: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_sub_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_sub_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_mul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_mul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_mul_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_mul_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_div: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_div: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_div_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_div_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_add_scalar_list: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_add_scalar_list_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_sub_scalar_list: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_sub_scalar_list_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_div_scalar_list: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_div_scalar_list_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_mul_scalar_list: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_mul_scalar_list_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_exp: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_exp_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_sqrt: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_sqrt_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_addcdiv_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_addcmul_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_addcdiv: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._foreach_addcmul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._mode: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._mode: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bucketize: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bucketize: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.bucketize: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.searchsorted: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.searchsorted: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.searchsorted: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.mkldnn_adaptive_avg_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._VariableFunctions._adaptive_avg_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._VariableFunctions.isfinite: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.isinf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.isposinf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.isposinf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.isneginf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.isneginf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._add_batch_dim: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._remove_batch_dim: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.fft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.det: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.outer: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.outer: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ger: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions.ger: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._VariableFunctions._test_serialization_subcmul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.binary_cross_entropy: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.binary_cross_entropy: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.linear: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.mkldnn_linear: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.gelu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.silu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.silu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.silu_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.one_hot: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.mkldnn_reorder_conv2d_weight: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.mkldnn_reorder_conv3d_weight: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.mse_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.mse_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.l1_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.l1_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.multi_margin_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.multi_margin_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.multilabel_margin_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.multilabel_margin_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.nll_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.nll_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.nll_loss2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.nll_loss2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.smooth_l1_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.smooth_l1_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.soft_margin_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.soft_margin_loss: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.elu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.elu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.elu_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.glu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.glu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.hardsigmoid: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.hardsigmoid: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.hardsigmoid_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.hardtanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.hardtanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.hardtanh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.hardswish: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.hardswish: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.hardswish_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.leaky_relu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.leaky_relu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.leaky_relu_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.log_sigmoid: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.log_sigmoid: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.rrelu_with_noise: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.rrelu_with_noise: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.rrelu_with_noise_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.softplus: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.softplus: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.softshrink: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.softshrink: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.adaptive_avg_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.adaptive_avg_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.adaptive_avg_pool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.adaptive_avg_pool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.adaptive_max_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.adaptive_max_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.adaptive_max_pool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.adaptive_max_pool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.avg_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.avg_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.avg_pool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.avg_pool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.fractional_max_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.fractional_max_pool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.fractional_max_pool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.fractional_max_pool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.max_pool2d_with_indices: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.max_pool2d_with_indices: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.max_pool3d_with_indices: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.max_pool3d_with_indices: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.max_unpool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.max_unpool2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.max_unpool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.max_unpool3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.reflection_pad1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.reflection_pad1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.reflection_pad2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 4}], + torch._C._nn.reflection_pad2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 4}], + torch._C._nn.replication_pad1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.replication_pad1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.replication_pad2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 4}], + torch._C._nn.replication_pad2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 4}], + torch._C._nn.replication_pad3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 6}], + torch._C._nn.replication_pad3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 6}], + torch._C._nn.upsample_linear1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_linear1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_linear1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_bilinear2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_bilinear2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_bilinear2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_trilinear3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_trilinear3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_trilinear3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_bicubic2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_bicubic2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_bicubic2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_nearest1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_nearest1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._nn.upsample_nearest1d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch._C._nn.upsample_nearest2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_nearest2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.upsample_nearest2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.upsample_nearest3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn.upsample_nearest3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.upsample_nearest3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.slow_conv_transpose2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.slow_conv_transpose2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.slow_conv_transpose3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.slow_conv_transpose3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.thnn_conv2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.thnn_conv2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.thnn_conv_depthwise2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.thnn_conv_depthwise2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.slow_conv3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.slow_conv3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.slow_conv_dilated2d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.slow_conv_dilated3d: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 3}], + torch._C._nn.col2im: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.col2im: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.im2col: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn.im2col: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn._test_optional_intlist: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch._C._nn._test_optional_filled_intlist: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 2}], + torch._C._nn._test_optional_floatlist: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.backward: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.rename_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.rename: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.align_to: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.align_to: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.align_as: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.refine_names: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.abs: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.abs_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.absolute: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.absolute_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.angle: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sgn: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sgn_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.conj: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.acos: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.acos_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arccos: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arccos_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.add_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addmv: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addmv_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addr: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addr_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.all: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.all: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.all: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.allclose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.any: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.any: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.any: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.argmax: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.argmin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.acosh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.acosh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arccosh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arccosh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.asinh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.asinh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arcsinh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arcsinh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.atanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.atanh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arctanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arctanh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.as_strided: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.as_strided_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.asin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.asin_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arcsin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arcsin_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.atan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.atan_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arctan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.arctan_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.baddbmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.baddbmm_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bernoulli: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bernoulli: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bernoulli_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bernoulli_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bincount: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_not: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_not_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logical_not: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logical_not_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logical_xor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logical_xor_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logical_and: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logical_and_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logical_or: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logical_or_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ceil: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ceil_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.unsafe_chunk: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.chunk: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.clamp: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.clamp_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.clamp_max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.clamp_max_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.clamp_min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.clamp_min_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.clip: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.clip_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cos: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cos_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cosh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cosh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.count_nonzero: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.count_nonzero: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cummax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cummax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cummin: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cummin: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cumprod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cumprod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cumsum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cumsum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.diag_embed: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.diagflat: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.diagonal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.diagonal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fill_diagonal_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.div: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.div_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.divide_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.divide_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.true_divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.true_divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.true_divide_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.true_divide_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.dot: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.vdot: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.new_empty: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.new_full: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.new_zeros: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.resize_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.erf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.erf_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.erfc: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.erfc_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.exp: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.exp_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.exp2: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.exp2_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.expm1: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.expm1_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.expand: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.expand_as: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.flatten: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.flatten: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.flatten: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.flatten: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.unflatten: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.unflatten: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fill_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fill_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.floor: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.floor_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.floor_divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.floor_divide: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.floor_divide_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.floor_divide_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.frac: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.frac_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.gcd: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.gcd_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lcm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lcm_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ifft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.rfft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.irfft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_copy_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_copy_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_copy: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_copy: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_put_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_put: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.inverse: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.isclose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.isnan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.is_distributed: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.is_floating_point: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.is_complex: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.isreal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.is_nonzero: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.is_same_size: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.is_signed: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.kthvalue: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.kthvalue: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log10: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log10_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log1p: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log1p_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log2: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log2_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logaddexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logaddexp2: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logdet: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log_softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logcumsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logcumsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.logsumexp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.matmul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.matrix_power: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.matrix_exp: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.max: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.max: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.amax: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.mean: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.mean: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.mean: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.median: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.median: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.median: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.min: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.min: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.amin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.mm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.mode: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.mode: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.mul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.mul_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.multiply: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.multiply: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.multiply_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.multiply_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.mv: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.mvlgamma: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.mvlgamma_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.narrow_copy: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.narrow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.narrow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.permute: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.movedim: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.movedim: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.is_pinned: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.pin_memory: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.pinverse: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.rad2deg: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.rad2deg_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.deg2rad: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.deg2rad_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.reciprocal: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.reciprocal_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.neg: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.neg_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.negative: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.negative_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.repeat: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.repeat_interleave: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.repeat_interleave: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.reshape: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.reshape_as: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.round: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.round_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.relu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.relu_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.prelu: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.hardshrink: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.rsqrt: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.rsqrt_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sigmoid: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sigmoid_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logit: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.logit_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sin: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sin_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sinh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sinh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.detach: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.detach_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.slogdet: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.smm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.softmax: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.unsafe_split: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.split: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.unsafe_split_with_sizes: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.split_with_sizes: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.squeeze: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.squeeze: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.squeeze: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.squeeze_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.squeeze_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.squeeze_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sspaddmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.stft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.istft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sum: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.sum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.nansum: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.nansum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.sum_to_size: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sqrt: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sqrt_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.square: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.square_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.std: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.std: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.std: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.prod: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.prod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.prod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.t: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.t_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.tan: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.tan_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.tanh: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.tanh_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.transpose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.transpose: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.transpose_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.flip: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fliplr: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.flipud: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.roll: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.rot90: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.trunc: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.trunc_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fix: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fix_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.type_as: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.unsqueeze: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.unsqueeze_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.var: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.var: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.var: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.view_as: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.where: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.norm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG', 'MSG': 1}], + torch.Tensor.clone: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.resize_as_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.zero_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sub: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sub_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.subtract: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.subtract: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.subtract_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.subtract_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.heaviside: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.heaviside_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addmm_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sparse_resize_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sparse_resize_and_clear_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sparse_mask: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.to_dense: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sparse_dim: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor._dimI: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.dense_dim: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor._dimV: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor._nnz: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.coalesce: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.is_coalesced: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor._indices: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor._values: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor._coalesced_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.indices: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.values: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.unbind: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.unbind: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.to_sparse: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.to_sparse: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.to_mkldnn: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.dequantize: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.q_scale: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.q_zero_point: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.q_per_channel_scales: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.q_per_channel_zero_points: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.q_per_channel_axis: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.int_repr: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.qscheme: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.is_set_to: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.masked_fill_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.masked_fill_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.masked_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.masked_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.masked_scatter_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.masked_scatter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.view: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.put_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_add_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_fill_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_fill_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_fill_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_fill_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_fill: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter_add_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter_add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.scatter_add: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.eq_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.eq_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_and: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_and: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_and_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_and_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__and__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__and__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__iand__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__iand__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_or: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_or: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_or_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_or_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__or__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__or__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__ior__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__ior__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_xor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_xor: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_xor_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.bitwise_xor_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__xor__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__xor__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__ixor__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__ixor__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__lshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__lshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__ilshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__ilshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__rshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__rshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__irshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.__irshift__: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lgamma_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.atan2_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.tril_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.triu_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.digamma_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.polygamma_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.renorm_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.pow_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.pow_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lerp_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lerp_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fmod_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fmod_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.remainder_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.remainder_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addbmm_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addbmm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addcdiv_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.random_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.random_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.random_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.uniform_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cauchy_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.log_normal_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.exponential_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.geometric_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.diag: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cross: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.triu: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.tril: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.trace: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ne: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ne: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ne_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ne_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.not_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.not_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.not_equal_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.not_equal_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.eq: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.eq: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ge: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ge: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ge_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ge_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.greater_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.greater_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.greater_equal_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.greater_equal_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.le: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.le: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.le_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.le_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.less_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.less_equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.less_equal_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.less_equal_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.gt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.gt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.gt_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.gt_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.greater: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.greater: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.greater_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.greater_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lt: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lt_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lt_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.less: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.less: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.less_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.less_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.take: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.index_select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.masked_select: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.gather: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.gather: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addcmul: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addcmul_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.addcdiv: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lstsq: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.triangular_solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.symeig: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.eig: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.svd: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cholesky: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cholesky_solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.cholesky_inverse: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.qr: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.geqrf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.orgqr: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ormqr: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lu_solve: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.multinomial: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lgamma: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.digamma: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.polygamma: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.erfinv: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.erfinv_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.i0: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.i0_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sign: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sign_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.signbit: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.dist: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.atan2: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lerp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.lerp: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.histc: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fmod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fmod: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.hypot: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.hypot_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.nextafter: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.nextafter_: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.remainder: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.remainder: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.maximum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.minimum: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.quantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.quantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.nanquantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.nanquantile: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sort: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.sort: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.argsort: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.argsort: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.topk: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.renorm: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.unfold: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.equal: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.pow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.pow: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.normal_: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.isfinite: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.isinf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.isposinf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.isneginf: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.fft: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.det: [{'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.outer: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], + torch.Tensor.ger: [{'MSG': 'MSG', 'MSG': 'MSG'}, {'MSG': 'MSG', 'MSG': 'MSG'}], +} +from __future__ import division +import torch +def div_int_future(): + return 1 / 2 +def div_float_future(): + return 3.14 / 0.125 +import torch +def div_int_nofuture(): + return 1 / 2 +def div_float_nofuture(): + return 3.14 / 0.125 from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union +import textwrap +import torch +from torch._C import TupleType, OptionalType, ListType +T = TypeVar("MSG") MAX_RAW_TENSOR_SIZE = 16 +class InflatableArg(NamedTuple): + value: Any + fmt: str +def augment_model_with_bundled_inputs( + model: torch.jit.ScriptModule, + inputs: Optional[List[Tuple[Any, ...]]] = None, + _receive_inflate_expr: Optional[List[str]] = None, +) -> None: + + if not isinstance(model, torch.jit.ScriptModule): + raise Exception("MSG") forward_arg_types = [arg.type for arg in model.forward.schema.arguments[1:]] + deflated_inputs_type: ListType = ListType(TupleType(forward_arg_types)) + inflated_inputs_type: OptionalType[ListType] = OptionalType(deflated_inputs_type) + model._c._register_attribute("MSG", deflated_inputs_type, []) + model._c._register_attribute("MSG", inflated_inputs_type, None) if hasattr(model, "MSG"): + if inputs is not None: + raise Exception( + "MSG") + elif inputs is None: + raise Exception( + "MSG") + else: + deflated_inputs = [] + parts = [] + for inp_idx, args in enumerate(inputs): + deflated_args = [] + parts.append("MSG") + for arg_idx, arg in enumerate(args): + deflated, inflater = _inflate_expr(arg, f"MSG") + deflated_args.append(deflated) + parts.append(f"MSG") + deflated_inputs.append(tuple(deflated_args)) + parts.append("MSG") + parts.append("MSG") + expr = "MSG".join(parts) + if _receive_inflate_expr is not None: + _receive_inflate_expr.append(expr) + model._bundled_inputs_deflated = deflated_inputs + definition = textwrap.dedent().format(expr) + model.define(definition) + model.define(textwrap.dedent()) + model.define(textwrap.dedent()) + model.define(textwrap.dedent()) +def _inflate_expr(arg: T, ref: str) -> Tuple[Union[T, torch.Tensor], str]: + + + + if isinstance(arg, InflatableArg): + return arg.value, arg.fmt.format(ref) if isinstance(arg, torch.Tensor): + if arg.storage().size() <= MAX_RAW_TENSOR_SIZE: + return arg, ref + if arg.is_contiguous() and arg.numel() <= MAX_RAW_TENSOR_SIZE: + return arg.clone(), ref + for fmt in [torch.contiguous_format, torch.channels_last]: + if arg.is_contiguous(memory_format=fmt) and (arg == arg.flatten()[0]).all().item(): + return (torch.tensor([arg.flatten()[0]]).expand(*arg.size()), + f"MSG") + raise Exception( + f"MSG" + f"MSG" + f"MSG" + ) + else: + return arg, ref +def bundle_randn(*size, dtype=None): + + stub = torch.zeros(1, dtype=dtype).expand(*size) + return InflatableArg(value=stub, fmt="MSG") +def bundle_large_tensor(t): + + return InflatableArg(value=t, fmt="MSG") +import torch +import warnings +from typing import Any, Iterable, List, Tuple +def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: + if isinstance(inputs, tuple): + out = [] + for inp in inputs: + if not isinstance(inp, torch.Tensor): + out.append(inp) + continue x = inp.detach() + x.requires_grad = inp.requires_grad + out.append(x) + return tuple(out) + else: + raise RuntimeError( + "MSG", type(inputs).__name__) +def check_backward_validity(inputs: Iterable[Any]) -> None: + if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)): + warnings.warn("MSG") def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: + + + fwd_gpu_devices = list(set(arg.get_device() for arg in args + if isinstance(arg, torch.Tensor) and arg.is_cuda)) fwd_gpu_states = [] + for device in fwd_gpu_devices: + with torch.cuda.device(device): + fwd_gpu_states.append(torch.cuda.get_rng_state()) return fwd_gpu_devices, fwd_gpu_states +def set_device_states(devices, states) -> None: + for device, state in zip(devices, states): + with torch.cuda.device(device): + torch.cuda.set_rng_state(state) +class CheckpointFunction(torch.autograd.Function): @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.preserve_rng_state = preserve_rng_state + if preserve_rng_state: + ctx.fwd_cpu_state = torch.get_rng_state() + ctx.had_cuda_in_fwd = False + if torch.cuda._initialized: + ctx.had_cuda_in_fwd = True + ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) + ctx.save_for_backward(*args) + with torch.no_grad(): + outputs = run_function(*args) + return outputs @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("MSG") + inputs = ctx.saved_tensors + rng_devices = [] + if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: + rng_devices = ctx.fwd_gpu_devices + with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_cuda_in_fwd: + set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) + detached_inputs = detach_variable(inputs) + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + torch.autograd.backward(outputs, args) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp + for inp in detached_inputs) + return (None, None) + grads +def checkpoint(function, *args, **kwargs): + r + + preserve = kwargs.pop('MSG', True) + if kwargs: + raise ValueError("MSG" + "MSG".join(arg for arg in kwargs)) return CheckpointFunction.apply(function, preserve, *args) +def checkpoint_sequential(functions, segments, input, **kwargs): + r + + preserve = kwargs.pop('MSG', True) + if kwargs: + raise ValueError("MSG" + "MSG".join(arg for arg in kwargs)) def run_function(start, end, functions): + def forward(input): + for j in range(start, end + 1): + input = functions[j](input) + return input + return forward if isinstance(functions, torch.nn.Sequential): + functions = list(functions.children()) segment_size = len(functions) // segments + + end = -1 + for start in range(0, segment_size * (segments - 1), segment_size): + end = start + segment_size - 1 + input = checkpoint(run_function(start, end, functions), input, + preserve_rng_state=preserve) + return run_function(end + 1, len(functions) - 1, functions)(input) +import locale +import re +import subprocess +import sys +import os +from collections import namedtuple try: + import torch + TORCH_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TORCH_AVAILABLE = False +SystemEnv = namedtuple('MSG', [ + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', + 'MSG', +]) +def run(command): + + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=True) + raw_output, raw_err = p.communicate() + rc = p.returncode + enc = locale.getpreferredencoding() + output = raw_output.decode(enc) + err = raw_err.decode(enc) + return rc, output.strip(), err.strip() +def run_and_read_all(run_lambda, command): + + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out +def run_and_parse_first_match(run_lambda, command, regex): + + rc, out, _ = run_lambda(command) + if rc != 0: + return None + match = re.search(regex, out) + if match is None: + return None + return match.group(1) +def get_conda_packages(run_lambda): + if get_platform() == 'MSG': + system_root = os.environ.get('MSG', 'MSG') + findstr_cmd = os.path.join(system_root, 'MSG', 'MSG') + grep_cmd = r'MSG'.format(findstr_cmd) + else: + grep_cmd = r'MSG' + conda = os.environ.get('MSG', 'MSG') + out = run_and_read_all(run_lambda, conda + 'MSG' + grep_cmd) + if out is None: + return out + + + if functions is not None: + module_def = [] + module_def.append('MSG') + if isinstance(functions, str): + functions = [functions] + if isinstance(functions, list): + functions = dict((f, f) for f in functions) + elif not isinstance(functions, dict): + raise ValueError( + "MSG".format( + type(functions))) + for function_name, docstring in functions.items(): + if with_pytorch_error_handling: + module_def.append( + 'MSG' + .format(function_name, docstring)) + else: + module_def.append('MSG'.format(function_name, docstring)) + module_def.append('MSG') + cpp_sources += module_def cpp_source_path = os.path.join(build_directory, 'MSG') + with open(cpp_source_path, 'MSG') as cpp_source_file: + cpp_source_file.write('MSG'.join(cpp_sources)) sources = [cpp_source_path] if cuda_sources: + cuda_sources.insert(0, 'MSG' + if not os.path.isdir(filename): + files.append(filename[len(self.directory) + 1:]) + return files from .importer import PackageImporter +from .exporter import PackageExporter diff --git a/model.py b/model.py index d07caae..ef9d8c3 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,7 @@ # coding: utf-8 from __future__ import print_function -import tensorflow as tf +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() import numpy as np import time import os @@ -23,7 +24,7 @@ def __init__(self, num_classes, num_seqs=64, num_steps=50, num_seqs, num_steps = 1, 1 else: num_seqs, num_steps = num_seqs, num_steps - + self.step = 0 self.num_classes = num_classes self.num_seqs = num_seqs self.num_steps = num_steps @@ -42,6 +43,12 @@ def __init__(self, num_classes, num_seqs=64, num_steps=50, self.build_optimizer() self.saver = tf.train.Saver() + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + self.session = tf.Session(config=config) + + self.restored = False + def build_inputs(self): with tf.name_scope('inputs'): self.inputs = tf.placeholder(tf.int32, shape=( @@ -101,14 +108,13 @@ def build_optimizer(self): self.optimizer = train_op.apply_gradients(zip(grads, tvars)) def train(self, batch_generator, max_steps, save_path, save_every_n, log_every_n): - self.session = tf.Session() with self.session as sess: - sess.run(tf.global_variables_initializer()) + if not self.restored: + sess.run(tf.global_variables_initializer()) # Train network - step = 0 new_state = sess.run(self.initial_state) for x, y in batch_generator: - step += 1 + self.step += 1 start = time.time() feed = {self.inputs: x, self.targets: y, @@ -121,15 +127,15 @@ def train(self, batch_generator, max_steps, save_path, save_every_n, log_every_n end = time.time() # control the print lines - if step % log_every_n == 0: - print('step: {}/{}... '.format(step, max_steps), + if self.step % log_every_n == 0: + print('step: {}/{}... '.format(self.step, max_steps), 'loss: {:.4f}... '.format(batch_loss), '{:.4f} sec/batch'.format((end - start))) - if (step % save_every_n == 0): - self.saver.save(sess, os.path.join(save_path, 'model'), global_step=step) - if step >= max_steps: + if (self.step % save_every_n == 0): + self.saver.save(sess, os.path.join(save_path, 'model'), global_step=self.step) + if self.step >= max_steps: break - self.saver.save(sess, os.path.join(save_path, 'model'), global_step=step) + self.saver.save(sess, os.path.join(save_path, 'model'), global_step=self.step) def sample(self, n_samples, prime, vocab_size): samples = [c for c in prime] @@ -169,3 +175,4 @@ def load(self, checkpoint): self.session = tf.Session() self.saver.restore(self.session, checkpoint) print('Restored from: {}'.format(checkpoint)) + self.restored = True diff --git a/out.txt b/out.txt new file mode 100644 index 0000000000000000000000000000000000000000..22f1b4aa36a15cf8722d43846193a853e26a47b6 GIT binary patch literal 672 zcma))O-chn5QX1*qIVdwA2S+sk%gOp%OF8qXe7iWI>Gs4GKt287jW-6TzdeImH1T) z+F*pxP!#oEe^oU-oAtf|b=A>8siqpKQchlKMeHhZl}LT{IFZg(;L5gY7bv;rI4Ye% zEA*&SN2=izXd&u|+okJu;Ixox2szT&d6rNr)h-TkjK~ul8R~>;s+O+YTm^A~{-hb) zgxK1`ur3eW^7|xk##Bp~Gv=!0ciOGjFKBnm$-rEH6%kt)Wec6bEo@rkY Ckay|; literal 0 HcmV?d00001 diff --git a/read_utils.py b/read_utils.py index ee17b09..6f2b22f 100644 --- a/read_utils.py +++ b/read_utils.py @@ -1,7 +1,8 @@ import numpy as np import copy import time -import tensorflow as tf +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() import pickle diff --git a/sample.py b/sample.py index fdfe0fb..14c57d1 100644 --- a/sample.py +++ b/sample.py @@ -1,4 +1,5 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() from read_utils import TextConverter from model import CharRNN import os @@ -17,7 +18,7 @@ def main(_): - FLAGS.start_string = FLAGS.start_string.decode('utf-8') + FLAGS.start_string = FLAGS.start_string converter = TextConverter(filename=FLAGS.converter_path) if os.path.isdir(FLAGS.checkpoint_path): FLAGS.checkpoint_path =\ diff --git a/train.py b/train.py index c114d8c..01635e4 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() from read_utils import TextConverter, batch_generator from model import CharRNN import os @@ -24,8 +25,12 @@ def main(_): model_path = os.path.join('model', FLAGS.name) + print(model_path) if os.path.exists(model_path) is False: os.makedirs(model_path) + path_exist = False + else: + path_exist = True with codecs.open(FLAGS.input_file, encoding='utf-8') as f: text = f.read() converter = TextConverter(text, FLAGS.max_vocab) @@ -44,6 +49,19 @@ def main(_): use_embedding=FLAGS.use_embedding, embedding_size=FLAGS.embedding_size ) + model_file_path = tf.train.latest_checkpoint(model_path) + if path_exist: + model.load(model_file_path) + indexes = [] + for dirpath, dirnames, filenames in os.walk(model_path): + for name in filenames: + filepath = os.path.join(dirpath, name) + if filepath.endswith(".index"): + indexes.append(int(name[6:-6])) + indexes.sort() + last_index = indexes[-1] + model.step = last_index + model.train(g, FLAGS.max_steps, model_path, From 04c9d15cf9b5fea63dcedaea1fbd9ca2ac3007ad Mon Sep 17 00:00:00 2001 From: LZY2006 <58110034+Lee-qian-gay@users.noreply.github.com> Date: Sun, 10 Jan 2021 12:40:47 +0800 Subject: [PATCH 02/12] =?UTF-8?q?Signed-off-by:=20LZY=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E8=87=AA=E8=BF=B0=E6=96=87=E6=A1=A3=20updated=20readme?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index c45f7c3..a9fc4ea 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,14 @@ Multi-language Char RNN in TensorFlow. You can use this code to generate English text, Chinese poetries and lyrics, Japanese text and text in other language. 一个基于最新版本TensorFlow的Char RNN实现。可以实现生成英文、写诗、歌词、小说、生成代码、生成日文等功能。 +LZY 在 HYZ 的基础上做了一些微小的改动来让此支持断点继续训练。 -## Requirements +## Requirements 环境要求 - Python 2.7.X - TensorFlow >= 1.2 -## Generate English Text +## Generate English Text 生成英文 To train: @@ -52,7 +53,7 @@ I the camples. ``` -## Generate Chinese Poetries +## Generate Chinese Poetries 生成中文诗 To train: @@ -88,7 +89,7 @@ Result: 何时有相访,不得在君心。 ``` -## Generate Chinese Novels +## Generate Chinese Novels 生成中文小说 To train (The file "novel.txt" is not included in this repo. You should find one and make sure it is utf-8 encoded!): ``` @@ -130,7 +131,7 @@ Result: “嗤!” ``` -## Generate Chinese Lyrics +## Generate Chinese Lyrics 生成中文歌词 To train: @@ -172,7 +173,7 @@ Result: 我们 你的我 你不会再会爱不到 ``` -## Generate Linux Code +## Generate Linux Code 生成 Linux 代码 To train: @@ -224,7 +225,7 @@ int print_init(struct priority *rt) } ``` -## Generate Japanese Text +## Generate Japanese Text 生成日文 To train: ``` From f4989562f07c5522e44849e723d9d4b4dd101261 Mon Sep 17 00:00:00 2001 From: LZY2006 <58110034+Lee-qian-gay@users.noreply.github.com> Date: Sun, 10 Jan 2021 12:42:35 +0800 Subject: [PATCH 03/12] =?UTF-8?q?=E5=88=A0=E9=99=A4=E4=B8=8D=E5=BF=85?= =?UTF-8?q?=E8=A6=81=E7=9A=84=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- out.txt | Bin 672 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 out.txt diff --git a/out.txt b/out.txt deleted file mode 100644 index 22f1b4aa36a15cf8722d43846193a853e26a47b6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 672 zcma))O-chn5QX1*qIVdwA2S+sk%gOp%OF8qXe7iWI>Gs4GKt287jW-6TzdeImH1T) z+F*pxP!#oEe^oU-oAtf|b=A>8siqpKQchlKMeHhZl}LT{IFZg(;L5gY7bv;rI4Ye% zEA*&SN2=izXd&u|+okJu;Ixox2szT&d6rNr)h-TkjK~ul8R~>;s+O+YTm^A~{-hb) zgxK1`ur3eW^7|xk##Bp~Gv=!0ciOGjFKBnm$-rEH6%kt)Wec6bEo@rkY Ckay|; From ad10686d7172aaf7a4a9d7ec853a38ec77ad0ea6 Mon Sep 17 00:00:00 2001 From: LZY2006 <58110034+Lee-qian-gay@users.noreply.github.com> Date: Sun, 10 Jan 2021 12:50:21 +0800 Subject: [PATCH 04/12] =?UTF-8?q?Signed-off-by:=20LZY=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E4=BA=86=E8=87=AA=E8=BF=B0=E6=96=87=E4=BB=B6=20updated=20readm?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a9fc4ea..f5a8402 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ Multi-language Char RNN in TensorFlow. You can use this code to generate English text, Chinese poetries and lyrics, Japanese text and text in other language. 一个基于最新版本TensorFlow的Char RNN实现。可以实现生成英文、写诗、歌词、小说、生成代码、生成日文等功能。 -LZY 在 HYZ 的基础上做了一些微小的改动来让此支持断点继续训练。 +LZY 在 HYZ 的基础上做了一些微小的改动来让此支持断点继续训练和Tensorflow 2.0。 ## Requirements 环境要求 From bbf615ca8a68a7eb8d7ca0237137216b86a0db80 Mon Sep 17 00:00:00 2001 From: LZY2006 <58110034+Lee-qian-gay@users.noreply.github.com> Date: Sun, 10 Jan 2021 17:42:54 +0800 Subject: [PATCH 05/12] =?UTF-8?q?Signed-off-by:=20LZY=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E4=BA=86predict=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E4=BA=86=E8=87=AA=E8=BF=B0=E6=96=87=E6=A1=A3=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 +++- .vscode/settings.json | 3 ++ README.md | 11 ++++++++ model.py | 65 +++++++++++++++++++++++++++++++++++++++++++ predict.py | 60 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 .vscode/settings.json create mode 100644 predict.py diff --git a/.gitignore b/.gitignore index d74fccd..9f90e85 100644 --- a/.gitignore +++ b/.gitignore @@ -109,4 +109,7 @@ training/ data/Bilibili.txt data/TheThreeBodyProblem.txt data/WangZengQi.txt -data/ZhaoHuaXiShi.txt \ No newline at end of file +data/ZhaoHuaXiShi.txt + +# vscode +.vscode/ \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..30e6c45 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.pythonPath": "C:\\Users\\zdwxx\\AppData\\Local\\Programs\\Python\\Python37\\python.exe" +} \ No newline at end of file diff --git a/README.md b/README.md index f5a8402..614d07d 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,17 @@ LZY 在 HYZ 的基础上做了一些微小的改动来让此支持断点继续 - Python 2.7.X - TensorFlow >= 1.2 +## How to use predict.py 预测字符的方法 +for example: 例如: +``` +python predict.py \ + --converter_path model/shakespeare/converter.pkl \ + --checkpoint_path model/shakespeare/ \ + --max_length 1000 +``` +使用和 sample.py 同样的参数。 +Use the same parameters as sample.py. + ## Generate English Text 生成英文 To train: diff --git a/model.py b/model.py index ef9d8c3..eccbc5d 100644 --- a/model.py +++ b/model.py @@ -176,3 +176,68 @@ def load(self, checkpoint): self.saver.restore(self.session, checkpoint) print('Restored from: {}'.format(checkpoint)) self.restored = True + + def predict(self, n_samples, prime, vocab_size, depth=5): + samples = [c for c in prime] + sess = self.session + new_state = sess.run(self.initial_state) + preds = np.ones((vocab_size, )) # for prime=[] + for c in prime: + x = np.zeros((1, 1)) + # 输入单个字符 + x[0, 0] = c + feed = {self.inputs: x, + self.keep_prob: 1., + self.initial_state: new_state} + preds, new_state = sess.run([self.proba_prediction, self.final_state], + feed_dict=feed) + + c = pick_top_n(preds, vocab_size) + # 添加字符到samples中 + samples.append(c) + + # 不断生成字符,直到达到指定数目 + x = np.zeros((1, 1)) + x[0, 0] = c + feed = {self.inputs: x, + self.keep_prob: 1., + self.initial_state: new_state} + preds, new_state = sess.run([self.proba_prediction, self.final_state], + feed_dict=feed) + + p = preds.copy() + p = p.reshape([p.shape[1]]) + c = np.argsort(-p)[:5] + p.sort() + p = p[::-1][:5] + p = p / np.sum(p) + top = [c, p] + + result = [] + for i in range(5): + c = top[0][i] + p = top[1][i] + + x = np.zeros((1, 1)) + x[0, 0] = c + feed = {self.inputs: x, + self.keep_prob: 1., + self.initial_state: new_state} + preds, new_state = sess.run([self.proba_prediction, self.final_state], + feed_dict=feed) + + generated = [c, ] + for i in range(depth): + x = np.zeros((1, 1)) + x[0, 0] = c + feed = {self.inputs: x, + self.keep_prob: 1., + self.initial_state: new_state} + preds, new_state = sess.run([self.proba_prediction, self.final_state], + feed_dict=feed) + + c = pick_top_n(preds, vocab_size, 1) + generated.append(c) + result.append([generated, p]) + + return result \ No newline at end of file diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..3718827 --- /dev/null +++ b/predict.py @@ -0,0 +1,60 @@ +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() +from read_utils import TextConverter +from model import CharRNN +import os +from IPython import embed + +import warnings +warnings.filterwarnings('ignore') + +FLAGS = tf.flags.FLAGS + +tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm') +tf.flags.DEFINE_integer('num_layers', 2, 'number of lstm layers') +tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding') +tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding') +tf.flags.DEFINE_string('converter_path', '', 'model/name/converter.pkl') +tf.flags.DEFINE_string('checkpoint_path', '', 'checkpoint path') +tf.flags.DEFINE_string('start_string', '', 'use this string to start generating') +tf.flags.DEFINE_integer('max_length', 30, 'max length to generate') + +def remove_return(s): + result = [] + for i in s: + if i == "\r": + result.append("\\r") + elif i == "\n": + result.append("\\n") + else: + result.append(i) + return str().join(result) + +def main(_): + FLAGS.start_string = FLAGS.start_string + converter = TextConverter(filename=FLAGS.converter_path) + if os.path.isdir(FLAGS.checkpoint_path): + FLAGS.checkpoint_path =\ + tf.train.latest_checkpoint(FLAGS.checkpoint_path) + + model = CharRNN(converter.vocab_size, sampling=True, + lstm_size=FLAGS.lstm_size, num_layers=FLAGS.num_layers, + use_embedding=FLAGS.use_embedding, + embedding_size=FLAGS.embedding_size) + + model.load(FLAGS.checkpoint_path) + + start = converter.text_to_arr(FLAGS.start_string) + arr = model.predict(FLAGS.max_length, start, converter.vocab_size, 10) + for c, p in arr: + prediction = converter.arr_to_text(c) + prediction = remove_return(prediction) + + # 如果有中文字生成,请将 {1:^14} 改为 {1:{4}^14} 以修复对齐问题。 + # {1:^14}中的 14 随着生成的字符数量而定,一般可以设为字符数+4 + + print("{0} -> {1:^14} {2} {3}".format(FLAGS.start_string, prediction, "probability:", p, chr(12288))) + + +if __name__ == '__main__': + tf.app.run() From f585b4502334d4814707282f9641b352085a03ec Mon Sep 17 00:00:00 2001 From: LZY2006 <58110034+Lee-qian-gay@users.noreply.github.com> Date: Sun, 10 Jan 2021 17:48:33 +0800 Subject: [PATCH 06/12] =?UTF-8?q?Signed-off-by:=20LZY=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E4=BA=86=E8=87=AA=E8=BF=B0=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 614d07d..4e7cda8 100644 --- a/README.md +++ b/README.md @@ -11,16 +11,26 @@ LZY 在 HYZ 的基础上做了一些微小的改动来让此支持断点继续 - TensorFlow >= 1.2 ## How to use predict.py 预测字符的方法 -for example: 例如: +for example 例如: ``` python predict.py \ - --converter_path model/shakespeare/converter.pkl \ - --checkpoint_path model/shakespeare/ \ - --max_length 1000 + --converter_path model/torch_gen/converter.pkl \ + --checkpoint_path model/torch_gen \ + --max_length 1500 \ + --start_string " raise " ``` 使用和 sample.py 同样的参数。 Use the same parameters as sample.py. +result 结果: +``` + raise -> utized_inpu probability: 0.6539345979690552 + raise -> es()\r\n probability: 0.1654084473848343 + raise -> pistent_and probability: 0.07784435153007507 + raise -> al_module_t probability: 0.0615621916949749 + raise -> Porgex(self probability: 0.04125040024518967 +``` + ## Generate English Text 生成英文 To train: From c4e50374acea62d60b74b9a2354e30c525f37d05 Mon Sep 17 00:00:00 2001 From: LZY2006 <58110034+Lee-qian-gay@users.noreply.github.com> Date: Sun, 10 Jan 2021 17:51:33 +0800 Subject: [PATCH 07/12] =?UTF-8?q?Signed-off-by:=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E8=87=AA=E8=BF=B0=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4e7cda8..b999a62 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,12 @@ Multi-language Char RNN in TensorFlow. You can use this code to generate English 一个基于最新版本TensorFlow的Char RNN实现。可以实现生成英文、写诗、歌词、小说、生成代码、生成日文等功能。 LZY 在 HYZ 的基础上做了一些微小的改动来让此支持断点继续训练和Tensorflow 2.0。 +增加了predict.py。 ## Requirements 环境要求 -- Python 2.7.X -- TensorFlow >= 1.2 +- Python 2.7.X 支持 Python 3.X +- TensorFlow >= 1.2 支持 TensorFlow 2.X ## How to use predict.py 预测字符的方法 for example 例如: From d46bb327bf20b08cae3be9599f249739050e1861 Mon Sep 17 00:00:00 2001 From: LZY2006 Date: Sun, 10 Jan 2021 17:55:54 +0800 Subject: [PATCH 08/12] =?UTF-8?q?Signed-off-by:=20LZY=20=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E4=B8=8D=E5=BF=85=E8=A6=81=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/settings.json | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 30e6c45..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python.pythonPath": "C:\\Users\\zdwxx\\AppData\\Local\\Programs\\Python\\Python37\\python.exe" -} \ No newline at end of file From b9fba78ce59c0dc0bc682e97c7a9129720ef8df6 Mon Sep 17 00:00:00 2001 From: LZY2006 Date: Sat, 23 Jan 2021 14:46:47 +0800 Subject: [PATCH 09/12] Signed-off-by: LZY2006 --- data/ElderJiang.txt | 35 +++++++++++++++++++++++++++++++++++ model.py | 3 +-- 2 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 data/ElderJiang.txt diff --git a/data/ElderJiang.txt b/data/ElderJiang.txt new file mode 100644 index 0000000..5925d85 --- /dev/null +++ b/data/ElderJiang.txt @@ -0,0 +1,35 @@ +我们需要有所选择,我们希望尽可能地限制对中国发展无用的信息。我认为所有国家和政党都必须有他们自己的出版物来宣传他们的主张,我们的确有新闻自由,但是这种自由应该从属并服从于国家的利益。 +我个人非常痛恨腐败,但我不认为这个问题能在一夜之间解决。为了逐步根除腐败,我们需要依靠法治的办法,用舆论的办法、教育的办法逐步地把它解决。89年风波期间的确有学生高呼反对腐败的口号,所以在这个特定的问题上,我们党和青年学生是站在同一战线上的。但事实上还有一小撮别有用心的人,企图利用学生的热情,妄图推翻中国共产党的领导并颠覆人民民主专政权。我们不能允许这种事情发生。我们要采取坚决的措施,否则我们就会失去努力维系至今的稳定,这无论对中国还是对世界都没好处。 +我只是说明这么一个问题,就是这么大一个国家,我们有12亿多人,新闻的对国家的导向确实是很重要的。你不能用美国的价值观来对中国的现状做评判,因为你们在经济、国民教育水平上高度发达,你更不能把美国模式强加于中国。不管中国媒体还是外国媒体,我都认为有一点很重要,那就是不能歪曲事实,即便他们自己由发表言论的自由。这对中国媒体很重要,特别是我们的《人民日报》,我们的老百姓非常重视。如果它把某一个事实报道错误了,人们会信以为真。我们不像你们那儿,你们想报道什么就报道什么,就算你们的报道出了什么错也无所谓,不会有很严重的后果。比方说,我现在正在北戴河这儿跟你交谈,但是我已经看到了一些外国报纸说我已经到了大连了。如果是真的话,那我现在怎么可能坐在这里跟你聊天?我再告诉你另一件事。几个月前,我看到一则消息说“江泽民访问厦门,突遭炸弹爆炸重伤送院”,但那时候我还在北京。 +我认为部队经商是一个腐蚀剂。因为历史经验已经告诉我们,任何一个国家如果军队经商以后,没有一个不腐败的,最后必然是涣散了军心。 +1943年的时候,我还在南京上大学。那时日本人已经占领了包括南京在内的许多中国领土。他们希望中国老百姓吸鸦片烟上瘾,所以我们自发组织了反鸦片运动,捣毁了许多烟馆。当我们遇到日军拿着刺刀和枪指向我们的时候,我们就唱了抗议歌曲《毕业歌》:“同学们!大家起来!担负起天下的兴亡!听吧!满耳是大众的嗟伤;看吧!一年年国土的沦丧!” +李洪志这个人是长春人,不过他自己说他是释迦牟尼的再生,是耶稣的再生,你怎么相信?他说世界的末日就要来临,说地球要爆炸了。而且他还说我和原来的李鹏总理给他打电话,叫他把地球的毁灭延后几十年,但我们从来就没有向他请求过。他企图通过这些宣传来获得信徒的信任,他给人留下一种他很了解中国领导人的印象 其实他是一派胡言。事实上,由于他的传教,导致许多家庭破裂,许多人也因此失去生命。因此经过谨慎的考虑,我们认为“法轮功”是邪教。而且,“法轮功”的成员也从未被判过死刑。 +不需要翻译了,我知道你想说什么。我很愿意回答这个问题,因为ABC的芭芭拉沃尔特斯十年前就问过我同样的问题,而且给我看了那个照片。我说过,这个年轻人没有被逮捕,也没有受到伤害,因为坦克在他面前停下来,拒绝碾过他。芭芭拉也告诉我这个年轻人的名字,我可能忘了,但我问过在公共安全情报机关工作的负责人,他们利用尽可能多的关系网络来寻找这个年轻人,经过一个月的调查,我们确信这个年轻人并没有被捕。 +吼啊。 +啊 当然啦。 +没听到过。 +彭定康……你们媒体千万要记着,不要“见着风,是得雨”。接到这些消息,你(媒体)本身也要判断,明白这个意思吗?像这种完全……无中生有的东西,你再帮他说一遍,你等于…这个东西…你…你也有责任吧。 +「你们有一个好,全世界跑到什么地方,你们比其他的西方记者啊,跑得还快。但是呢,问来问去的问题啊,都 too simple,啊,sometimes naïve!」 +没有……没有任何(内定、钦点)的意思。还是按照香港的……按照基本法、按照选举的法——去产生…… +你……刚才你问我啊,我可以回答你一句“无可奉告”,那你们又不高兴,那怎么办? +我讲的意思不是我是钦点他当下一任。你问我不支支持不支持,我说支持。我就明确可以告诉你。 +我觉得你们啊,你们我感觉你们新闻界还要学习一个,你们非常熟悉西方这一套的媒体。你们毕竟还 too young(太年轻),明白我的意思吧?我告诉你们我是身经百战了,见得多啦!欸,西方的哪一个国家我没去过?媒体他们——你…你们要知道,美国的华莱士,比你们不知道高到哪里去啦!欸,我跟他谈笑风生!所以说媒体啊,要……还是要提高自己的知识水平!懂我的意思——識得唔識得啊?(懂不懂啊?) +唉,我也给你们着急啊,真的。 +你们今日……我以为……遍地……你们有一个好,全世界跑到什么地方,你们比其他的西方记者啊,跑得还快。但是呢,问来问去的问题啊,都 too simple(太肤浅),啊,sometimes naïve!(有时很幼稚!)懂了没啊? +识得唔识得啊? +(一片嘈杂声) +我很抱歉,我今天是作为一个长者,跟你们讲的。我不是新闻工作者,但是我见得太多了,我……我有这个必要啊告诉你们一点,人生的经验。 +记者:但是能不能说一下为甚麽支持董建华呢? +我刚才呢……我刚才我很想啊,就是我每一次碰到你们我就讲,中国有一句话叫「闷声大发财」。我就什么话也不说。就是最好的!但是我想我见到你们这样热情啊,一句话不说也不好。所以你刚才你一定要——在宣传上将来如果你们报道上有偏差你们要负责的。我没有说要钦定(董建华),没任何这个意思。但是你问……你一定要不得要问我…对对对...对董先生支持不支持。我们不支持他?他现在是当特首,我们怎么能不支持特首? +记者:但是如果说连任呢? +对不对? +欸,连任也要按照香港的法律啊,对不对?要要……要要按照香港的……当然,我们的决定权也是很重要的。香港的特区……特别行政区是属于中国……人民共和(中华人民共和国)的中央人民政府啊。啊?到那时候我们会表态的! +记者:但是呢…… +明白这个意思吧? +你们啊,不要想……喜欢……这…欸呵弄个大新闻,说现在已经钦定了,再把我批判一番。 +你们啊,naïve!(幼稚!) +记者:但是呢就是…… +保安人员:好好好OKOK…… +I'm angry!(我生气了!)我跟你讲,你们这样子啊,那不行的! +保安人员:好好好,请大家离场。 +我今天算得罪了你们一下! \ No newline at end of file diff --git a/model.py b/model.py index eccbc5d..9e46e8e 100644 --- a/model.py +++ b/model.py @@ -193,10 +193,9 @@ def predict(self, n_samples, prime, vocab_size, depth=5): feed_dict=feed) c = pick_top_n(preds, vocab_size) - # 添加字符到samples中 + samples.append(c) - # 不断生成字符,直到达到指定数目 x = np.zeros((1, 1)) x[0, 0] = c feed = {self.inputs: x, From de21f3127d6d46dbb6fc78b2e6bccb94982ca426 Mon Sep 17 00:00:00 2001 From: LZY2006 Date: Sat, 23 Jan 2021 15:09:43 +0800 Subject: [PATCH 10/12] =?UTF-8?q?Signed-off-by:=20LZY2006=20=20=E4=BF=AE=E5=A4=8DBUG=20fix=20BUGs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/model.py b/model.py index 9e46e8e..fc109a8 100644 --- a/model.py +++ b/model.py @@ -145,6 +145,7 @@ def sample(self, n_samples, prime, vocab_size): for c in prime: x = np.zeros((1, 1)) # 输入单个字符 + # nai+v=e x[0, 0] = c feed = {self.inputs: x, self.keep_prob: 1., @@ -192,21 +193,10 @@ def predict(self, n_samples, prime, vocab_size, depth=5): preds, new_state = sess.run([self.proba_prediction, self.final_state], feed_dict=feed) - c = pick_top_n(preds, vocab_size) - - samples.append(c) - - x = np.zeros((1, 1)) - x[0, 0] = c - feed = {self.inputs: x, - self.keep_prob: 1., - self.initial_state: new_state} - preds, new_state = sess.run([self.proba_prediction, self.final_state], - feed_dict=feed) - + # state: naiv p = preds.copy() p = p.reshape([p.shape[1]]) - c = np.argsort(-p)[:5] + c = np.argsort(-p)[:5] # e ... p.sort() p = p[::-1][:5] p = p / np.sum(p) @@ -214,18 +204,13 @@ def predict(self, n_samples, prime, vocab_size, depth=5): result = [] for i in range(5): - c = top[0][i] - p = top[1][i] - - x = np.zeros((1, 1)) - x[0, 0] = c - feed = {self.inputs: x, - self.keep_prob: 1., - self.initial_state: new_state} - preds, new_state = sess.run([self.proba_prediction, self.final_state], - feed_dict=feed) + c = top[0][i] # e + p = top[1][i] # naiv + + # pred:e state:naiv generated = [c, ] + # generated:[e,] for i in range(depth): x = np.zeros((1, 1)) x[0, 0] = c From 2f19ccd3420a738f466831f3a92c16a0e45fcc7f Mon Sep 17 00:00:00 2001 From: LZY2006 Date: Sat, 23 Jan 2021 15:19:59 +0800 Subject: [PATCH 11/12] =?UTF-8?q?Signed-off-by:=20LZY2006=20=20=E5=A2=9E=E5=8A=A0=E4=BA=86=E7=94=A8=E4=BA=8E?= =?UTF-8?q?=E9=A2=84=E6=B5=8B=E3=80=81=E8=AE=AD=E7=BB=83=E3=80=81=E7=94=9F?= =?UTF-8?q?=E6=88=90=E7=9A=84=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 7 +++---- sample/predict_Jiang.bat | 7 +++++++ sample/predict_bilibili.bat | 7 +++++++ sample/predict_torch_code.bat | 6 ++++++ sample/sample_Jiang.bat | 7 +++++++ sample/sample_bilibili.bat | 7 +++++++ sample/sample_tf_code.bat | 5 +++++ sample/sample_three.bat | 9 +++++++++ sample/sample_torch_code.bat | 6 ++++++ sample/sample_wangZQ.bat | 9 +++++++++ train/train_Jiang.bat | 10 ++++++++++ train/train_bilibili.bat | 10 ++++++++++ train/train_tf_code.bat | 7 +++++++ train/train_three.bat | 11 +++++++++++ train/train_torch_code.bat | 7 +++++++ train/train_wangZengQi.bat | 12 ++++++++++++ 16 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 sample/predict_Jiang.bat create mode 100644 sample/predict_bilibili.bat create mode 100644 sample/predict_torch_code.bat create mode 100644 sample/sample_Jiang.bat create mode 100644 sample/sample_bilibili.bat create mode 100644 sample/sample_tf_code.bat create mode 100644 sample/sample_three.bat create mode 100644 sample/sample_torch_code.bat create mode 100644 sample/sample_wangZQ.bat create mode 100644 train/train_Jiang.bat create mode 100644 train/train_bilibili.bat create mode 100644 train/train_tf_code.bat create mode 100644 train/train_three.bat create mode 100644 train/train_torch_code.bat create mode 100644 train/train_wangZengQi.bat diff --git a/.gitignore b/.gitignore index 9f90e85..7a9c79b 100644 --- a/.gitignore +++ b/.gitignore @@ -103,13 +103,12 @@ ENV/ model/ -sampling/ -training/ - data/Bilibili.txt data/TheThreeBodyProblem.txt data/WangZengQi.txt data/ZhaoHuaXiShi.txt # vscode -.vscode/ \ No newline at end of file +.vscode/ + +models.7z \ No newline at end of file diff --git a/sample/predict_Jiang.bat b/sample/predict_Jiang.bat new file mode 100644 index 0000000..d0f6164 --- /dev/null +++ b/sample/predict_Jiang.bat @@ -0,0 +1,7 @@ +cd .. +python predict.py --converter_path model/ElderJiang/converter.pkl ^ + --checkpoint_path model/ElderJiang ^ + --max_length 500 ^ + --use_embedding ^ + --num_layers 3 ^ + --start_string "sometimes na" \ No newline at end of file diff --git a/sample/predict_bilibili.bat b/sample/predict_bilibili.bat new file mode 100644 index 0000000..619241c --- /dev/null +++ b/sample/predict_bilibili.bat @@ -0,0 +1,7 @@ +cd .. +python predict.py --converter_path model/Bilibili/converter.pkl ^ + --checkpoint_path model/Bilibili ^ + --max_length 500 ^ + --use_embedding ^ + --num_layers 3 ^ + --start_string "ĸﴺ" \ No newline at end of file diff --git a/sample/predict_torch_code.bat b/sample/predict_torch_code.bat new file mode 100644 index 0000000..acbdccb --- /dev/null +++ b/sample/predict_torch_code.bat @@ -0,0 +1,6 @@ +cd .. +python predict.py ^ + --converter_path model/torch_gen/converter.pkl ^ + --checkpoint_path model/torch_gen ^ + --max_length 1500 ^ + --start_string " def " \ No newline at end of file diff --git a/sample/sample_Jiang.bat b/sample/sample_Jiang.bat new file mode 100644 index 0000000..0681a31 --- /dev/null +++ b/sample/sample_Jiang.bat @@ -0,0 +1,7 @@ +cd .. +python sample.py --converter_path model/ElderJiang/converter.pkl ^ + --checkpoint_path model/ElderJiang ^ + --max_length 500 ^ + --use_embedding ^ + --num_layers 3 ^ + --start_string "sometimes na" \ No newline at end of file diff --git a/sample/sample_bilibili.bat b/sample/sample_bilibili.bat new file mode 100644 index 0000000..2485124 --- /dev/null +++ b/sample/sample_bilibili.bat @@ -0,0 +1,7 @@ +cd .. +python sample.py --converter_path model/Bilibili/converter.pkl ^ + --checkpoint_path model/Bilibili ^ + --max_length 500 ^ + --use_embedding ^ + --num_layers 3 ^ + --start_string "我觉得" \ No newline at end of file diff --git a/sample/sample_tf_code.bat b/sample/sample_tf_code.bat new file mode 100644 index 0000000..5a45741 --- /dev/null +++ b/sample/sample_tf_code.bat @@ -0,0 +1,5 @@ +python sample.py ^ + --converter_path model/linux/converter.pkl ^ + --checkpoint_path model/linux ^ + --max_length 1500 ^ + --start_string " raise" \ No newline at end of file diff --git a/sample/sample_three.bat b/sample/sample_three.bat new file mode 100644 index 0000000..3fb0630 --- /dev/null +++ b/sample/sample_three.bat @@ -0,0 +1,9 @@ +python sample.py ^ + --converter_path model/novel/converter.pkl ^ + --checkpoint_path model/novel ^ + --use_embedding ^ + --max_length 2000 ^ + --num_layers 3 ^ + --lstm_size 256 ^ + --embedding_size 256 + --start_string "其实" \ No newline at end of file diff --git a/sample/sample_torch_code.bat b/sample/sample_torch_code.bat new file mode 100644 index 0000000..46e2ec9 --- /dev/null +++ b/sample/sample_torch_code.bat @@ -0,0 +1,6 @@ +cd .. +python sample.py ^ + --converter_path model/torch_gen/converter.pkl ^ + --checkpoint_path model/torch_gen ^ + --max_length 1500 ^ + --start_string " raise " \ No newline at end of file diff --git a/sample/sample_wangZQ.bat b/sample/sample_wangZQ.bat new file mode 100644 index 0000000..cf8c5c8 --- /dev/null +++ b/sample/sample_wangZQ.bat @@ -0,0 +1,9 @@ +python sample.py ^ + --converter_path model/Zhaohuaxishi/converter.pkl ^ + --checkpoint_path model/Zhaohuaxishi ^ + --use_embedding ^ + --max_length 1000 ^ + --num_layers 3 ^ + --lstm_size 256 ^ + --embedding_size 256 ^ + --start_string "ĺ" \ No newline at end of file diff --git a/train/train_Jiang.bat b/train/train_Jiang.bat new file mode 100644 index 0000000..d3dea0a --- /dev/null +++ b/train/train_Jiang.bat @@ -0,0 +1,10 @@ +cd .. +python train.py ^ + --input_file data/ElderJiang.txt ^ + --num_steps 20 ^ + --batch_size 32 ^ + --name ElderJiang ^ + --max_steps 10000 ^ + --learning_rate 0.01 ^ + --num_layers 3 ^ + --use_embedding \ No newline at end of file diff --git a/train/train_bilibili.bat b/train/train_bilibili.bat new file mode 100644 index 0000000..82d3573 --- /dev/null +++ b/train/train_bilibili.bat @@ -0,0 +1,10 @@ +cd .. +python train.py ^ + --input_file data/Bilibili.txt ^ + --num_steps 20 ^ + --batch_size 32 ^ + --name Bilibili ^ + --max_steps 10000 ^ + --learning_rate 0.01 ^ + --num_layers 3 ^ + --use_embedding \ No newline at end of file diff --git a/train/train_tf_code.bat b/train/train_tf_code.bat new file mode 100644 index 0000000..5a8de80 --- /dev/null +++ b/train/train_tf_code.bat @@ -0,0 +1,7 @@ +python train.py ^ + --input_file data/tensorflow_code.txt ^ + --num_steps 100 ^ + --name linux ^ + --learning_rate 0.01 ^ + --num_seqs 32 ^ + --max_steps 20000 \ No newline at end of file diff --git a/train/train_three.bat b/train/train_three.bat new file mode 100644 index 0000000..42b2e78 --- /dev/null +++ b/train/train_three.bat @@ -0,0 +1,11 @@ +python train.py ^ + --use_embedding True ^ + --input_file data/TheThreeBodyProblem.txt ^ + --num_steps 80 ^ + --name novel ^ + --learning_rate 0.005 ^ + --num_seqs 32 ^ + --num_layers 3 ^ + --embedding_size 256 ^ + --lstm_size 256 ^ + --max_steps 10000 \ No newline at end of file diff --git a/train/train_torch_code.bat b/train/train_torch_code.bat new file mode 100644 index 0000000..651103b --- /dev/null +++ b/train/train_torch_code.bat @@ -0,0 +1,7 @@ +python train.py ^ + --input_file data/torch_code.txt ^ + --num_steps 100 ^ + --name torch_gen ^ + --learning_rate 0.01 ^ + --num_seqs 32 ^ + --max_steps 20000 \ No newline at end of file diff --git a/train/train_wangZengQi.bat b/train/train_wangZengQi.bat new file mode 100644 index 0000000..14a025d --- /dev/null +++ b/train/train_wangZengQi.bat @@ -0,0 +1,12 @@ +rem WangZengQi +python train.py ^ + --use_embedding True ^ + --input_file data/ZhaoHuaXiShi.txt ^ + --num_steps 80 ^ + --name Zhaohuaxishi ^ + --learning_rate 0.005 ^ + --num_seqs 32 ^ + --num_layers 3 ^ + --embedding_size 256 ^ + --lstm_size 256 ^ + --max_steps 100000 \ No newline at end of file From 405928a1319339f2774f71cbb7fce2d2941f574f Mon Sep 17 00:00:00 2001 From: LZY <58110034+LZY2006@users.noreply.github.com> Date: Thu, 28 Jan 2021 15:12:05 +0800 Subject: [PATCH 12/12] Create LICENSE --- LICENSE | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3ef74fc --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 LZY + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.