From 5f041b0872056bd0ed0d38c978fdf9c9f2034e54 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Sun, 8 Feb 2026 18:53:44 +0800 Subject: [PATCH] Refactor hpc-ops to use tvm-ffi, decoupling compiled binaries from specific torch version (#1) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: oraluben <5031346+oraluben@users.noreply.github.com> --- .gitignore | 4 + CMakeLists.txt | 136 ++-- Makefile | 2 +- README.md | 28 +- cmake/FindPipCUDAToolkit.cmake | 70 +++ cmake/find_pip_cuda.py | 102 +++ conftest.py | 3 +- hpc/__init__.py | 32 +- hpc/act.py | 28 + hpc/attention.py | 54 ++ hpc/fuse_moe.py | 75 ++- hpc/group_gemm.py | 35 ++ hpc/ops/.gitkeep | 0 pyproject.toml | 59 ++ requirements-dev.txt | 4 +- setup.py | 73 --- src/C/C.cc | 7 +- src/C/built_json.cu | 6 +- src/C/version.cc | 7 +- src/activation/entry.cc | 196 +++--- src/attention/entry.cc | 502 ++++++++------- src/fuse_moe/entry.cc | 590 +++++++++--------- src/group_gemm/entry.cc | 265 ++++---- src/utils/include/tvm_ffi_utils.h | 57 ++ tests/test_act.py | 2 +- tests/test_attention_decode_bf16.py | 2 +- tests/test_attention_decode_fp8.py | 2 +- tests/test_attention_prefill_bf16.py | 2 +- ...est_attention_with_kvcache_prefill_bf16.py | 2 +- ...test_attention_with_kvcache_prefill_fp8.py | 2 +- tests/test_fuse_moe_blockwise.py | 2 +- tests/test_fuse_moe_pertensor.py | 2 +- tests/test_group_gemm_blockwise.py | 2 +- tests/test_group_gemm_pertensor.py | 2 +- tests/test_version.py | 2 +- version_with_meta.py | 41 ++ 36 files changed, 1421 insertions(+), 977 deletions(-) create mode 100644 cmake/FindPipCUDAToolkit.cmake create mode 100644 cmake/find_pip_cuda.py create mode 100644 hpc/ops/.gitkeep create mode 100644 pyproject.toml delete mode 100644 setup.py create mode 100644 src/utils/include/tvm_ffi_utils.h create mode 100644 version_with_meta.py diff --git a/.gitignore b/.gitignore index a0cae23..8f9230f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,9 @@ site/ docs/ .vscode/ *.egg-info/ +CMakeFiles/ _C.*.so +_C.so __pycache__ *.ncu-rep *.nsys-rep @@ -12,3 +14,5 @@ __pycache__ *.log test version.py +compile_commands.json +_codeql_detected_source_root diff --git a/CMakeLists.txt b/CMakeLists.txt index 81c2f30..be6a876 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,26 +1,67 @@ -cmake_minimum_required(VERSION 3.18) +cmake_minimum_required(VERSION 3.27 FATAL_ERROR) + +# Detect CUDA toolkit: tries host installation first, then falls back to +# pip-installed packages (env WITH_PIP_CUDA_TOOLCHAIN or auto-detect). +# Must be included before project() so CMAKE_CUDA_COMPILER is set. 
+include(${CMAKE_CURRENT_LIST_DIR}/cmake/FindPipCUDAToolkit.cmake) + project(hpc_ops LANGUAGES CXX CUDA) -enable_language(CUDA) -set(CMAKE_BUILD_TYPE "Release") +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") +set(CMAKE_CUDA_RUNTIME_LIBRARY None) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -find_package(CUDAToolkit REQUIRED) -find_package(Python3 COMPONENTS Interpreter Development.Module REQUIRED) +set(CMAKE_CUDA_ARCHITECTURES "90a") -file(GLOB_RECURSE SOURCES "src/*/*.cu" "src/*/*.cc") -list(FILTER SOURCES EXCLUDE REGEX ".*test.*") +find_package(CUDAToolkit REQUIRED) +include_directories(${CUDAToolkit_INCLUDE_DIRS}) +link_directories(${CUDAToolkit_LIBRARY_DIR} ${CUDAToolkit_LIBRARY_DIR}/stubs) + +find_program(CCACHE_FOUND ccache) +if(CCACHE_FOUND) + set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_FOUND}") + set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_FOUND}") + set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_FOUND}") +endif() + +find_package( + Python + COMPONENTS Interpreter + REQUIRED +) +find_package(tvm_ffi CONFIG REQUIRED) +if(NOT "${SKBUILD_SABI_VERSION}" STREQUAL "") + set(USE_SABI USE_SABI ${SKBUILD_SABI_VERSION}) +endif() -add_library(_C MODULE ${SOURCES}) +include_directories( + ./ + src/utils/include + 3rd/cutlass/include +) +# Collect all CUDA source files (kernels) +file(GLOB_RECURSE CUDA_SOURCES "src/*/*.cu") +list(FILTER CUDA_SOURCES EXCLUDE REGEX ".*test.*") +# Exclude built_json.cu as it's compiled separately +list(FILTER CUDA_SOURCES EXCLUDE REGEX ".*/C/built_json\\.cu$") + +# Collect all CC source files (entry points) +file(GLOB_RECURSE CC_SOURCES "src/*/*.cc") +list(FILTER CC_SOURCES EXCLUDE REGEX ".*test.*") +# Exclude C.cc placeholder +list(FILTER CC_SOURCES EXCLUDE REGEX ".*/C/C\\.cc$") + +# Build all kernels as a single shared library +add_library(_C SHARED ${CUDA_SOURCES} ${CC_SOURCES} src/C/built_json.cu) +target_link_libraries(_C PRIVATE tvm_ffi::shared cuda cudart) set_target_properties( _C PROPERTIES - LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/ - - OUTPUT_NAME "_C" PREFIX "" - SUFFIX ".abi3.so" - CUDA_RUNTIME_LIBRARY "Shared" + SUFFIX ".so" + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/ CUDA_SEPARABLE_COMPILATION OFF CUDA_RESOLVE_DEVICE_SYMBOLS ON @@ -29,7 +70,6 @@ set_target_properties( C_VISIBILITY_PRESET "hidden" CXX_VISIBILITY_PRESET "hidden" VISIBILITY_INLINES_HIDDEN ON - CUDA_VISIBILITY_PRESET "hidden" CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON @@ -42,59 +82,35 @@ set_target_properties( CUDA_ARCHITECTURES "90a" ) +if(NOT DEFINED HPC_GIT_HASH_STR OR HPC_GIT_HASH_STR STREQUAL "") + execute_process( + COMMAND git rev-parse --short=7 HEAD + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + OUTPUT_VARIABLE HPC_GIT_HASH_STR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + if(NOT HPC_GIT_HASH_STR) + set(HPC_GIT_HASH_STR "unknown") + endif() +endif() + +if(NOT DEFINED HPC_VERSION_STR OR HPC_VERSION_STR STREQUAL "") + set(HPC_VERSION_STR "0.0.1-dev") +endif() + target_compile_definitions( _C PRIVATE - Py_LIMITED_API=0x03090000 _GLIBCXX_USE_CXX11_ABI=1 - HPC_GIT_HASH_STR=${HPC_GIT_HASH_STR} - HPC_VERSION_STR=${HPC_VERSION_STR} -) - -execute_process( - COMMAND python3 -c " -from torch.utils.cpp_extension import include_paths -print(';'.join(include_paths()), end='') -" - OUTPUT_VARIABLE TORCH_INCLUDE_PATHS -) - -execute_process( - COMMAND python3 -c " -from torch.utils.cpp_extension import library_paths -print(';'.join(library_paths()), end='') -" - OUTPUT_VARIABLE TORCH_LIBRARY_PATHS -) - - -target_include_directories( - _C PRIVATE - ./ - 3rd/cutlass/include 
- ${CUDAToolkit_INCLUDE_DIRS} - ${TORCH_INCLUDE_PATHS} -) - -target_link_directories( - _C PRIVATE - ${TORCH_LIBRARY_PATHS} -) - -target_link_libraries( - _C PRIVATE - cuda - c10 - torch - torch_cpu - cudart - c10_cuda - torch_cuda + HPC_GIT_HASH_STR="${HPC_GIT_HASH_STR}" + HPC_VERSION_STR="${HPC_VERSION_STR}" ) target_compile_options( _C PRIVATE $<$: -Werror=all-warnings + -Wno-error=deprecated-declarations -lineinfo --expt-relaxed-constexpr -std=c++17 @@ -108,7 +124,7 @@ target_compile_options( -g -fwrapv -Wall - -DTORCH_API_INCLUDE_EXTENSION_H - -DTORCH_EXTENSION_NAME=_C > ) + +install(TARGETS _C LIBRARY DESTINATION hpc/ops) diff --git a/Makefile b/Makefile index 4ba5d99..4e77788 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ CSRC_FILES=$(CC_FILES) $(CU_FILES) $(CUH_FILES) $(H_FILES) all: - python3 setup.py build + pip install --no-build-isolation -e . wheel: find . -type d -name "__pycache__" -exec rm -rf {} + diff --git a/README.md b/README.md index a58585f..fa247ca 100644 --- a/README.md +++ b/README.md @@ -43,13 +43,37 @@ HPC-Ops is a **production-grade, high-performance, and easy-to-use** operator li *You can set up the environment by installing the modules listed in requirements-dev.txt.* ### Install from Source + +#### With host CUDA toolchain ```bash git clone https://github.com/Tencent/hpc-ops.git cd hpc-ops -# build packages +# Ensure CUDA toolkit is installed on the host (e.g., /usr/local/cuda) +pip install . -v +``` + +#### With pip-provided CUDA toolchain (no host CUDA required) + +Option A — pip toolchain in the current environment (use `--no-build-isolation`): + +```bash +pip install nvidia-cuda-nvcc nvidia-cuda-cccl scikit-build-core cmake ninja +pip install . -v --no-build-isolation +``` + +Option B — pip toolchain in another virtualenv or path: + +```bash +# Point to the cu directory inside another venv's site-packages +export WITH_PIP_CUDA_TOOLCHAIN=/path/to/venv/lib/python3.x/site-packages/nvidia/cu13 +pip install . -v +``` + +#### Build wheel +```bash make wheel -python3 -m pip install dist/*.whl +python3 -m pip install dist/*.whl ``` ### Basic Usage diff --git a/cmake/FindPipCUDAToolkit.cmake b/cmake/FindPipCUDAToolkit.cmake new file mode 100644 index 0000000..29cb3f3 --- /dev/null +++ b/cmake/FindPipCUDAToolkit.cmake @@ -0,0 +1,70 @@ +# FindPipCUDAToolkit.cmake +# +# Locate CUDA toolkit — first trying the host system, then falling back +# to pip-installed packages (nvidia-cuda-nvcc, nvidia-cuda-cccl). +# +# This module should be included BEFORE project() to set CMAKE_CUDA_COMPILER +# when pip CUDA is used. +# +# Detection order: +# 1. Try find_package(CUDAToolkit QUIET) — succeeds if a host CUDA +# installation is available; skip pip detection. +# 2. If env var WITH_PIP_CUDA_TOOLCHAIN is set to a path (e.g., .../cu13), +# use that directory directly as the CUDA toolkit root. +# 3. Otherwise, try auto-detecting from the current Python environment's +# site-packages (works with --no-build-isolation). 
+ +# --- Try host CUDA first --- +find_package(CUDAToolkit QUIET) +if(CUDAToolkit_FOUND) + return() +endif() + +find_program(_PIP_CUDA_PYTHON_EXE NAMES python3 python) +if(NOT _PIP_CUDA_PYTHON_EXE) + return() +endif() + +# --- Strategy 1: explicit path via env var --- +if(DEFINED ENV{WITH_PIP_CUDA_TOOLCHAIN}) + set(_PIP_CUDA_ROOT "$ENV{WITH_PIP_CUDA_TOOLCHAIN}") + if(NOT EXISTS "${_PIP_CUDA_ROOT}/bin/nvcc") + message(FATAL_ERROR + "FindPipCUDAToolkit: WITH_PIP_CUDA_TOOLCHAIN is set to '${_PIP_CUDA_ROOT}' " + "but nvcc was not found at '${_PIP_CUDA_ROOT}/bin/nvcc'") + endif() + # Prepare the directory (create lib64 symlink, unversioned .so symlinks, + # libcuda.so stub) that CMake / nvcc expect but pip packages omit. + execute_process( + COMMAND "${_PIP_CUDA_PYTHON_EXE}" "${CMAKE_CURRENT_LIST_DIR}/find_pip_cuda.py" + "${_PIP_CUDA_ROOT}" + OUTPUT_QUIET + ) + message(STATUS "FindPipCUDAToolkit: using env WITH_PIP_CUDA_TOOLCHAIN=${_PIP_CUDA_ROOT}") +else() + # --- Strategy 2: auto-detect from current Python env --- + execute_process( + COMMAND "${_PIP_CUDA_PYTHON_EXE}" "${CMAKE_CURRENT_LIST_DIR}/find_pip_cuda.py" + OUTPUT_VARIABLE _PIP_CUDA_OUTPUT + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE _PIP_CUDA_RESULT + ) + + if(NOT _PIP_CUDA_RESULT EQUAL 0) + message(STATUS "FindPipCUDAToolkit: pip-installed CUDA toolkit not found") + return() + endif() + + string(JSON _PIP_CUDA_ROOT GET "${_PIP_CUDA_OUTPUT}" "root") + message(STATUS "FindPipCUDAToolkit: auto-detected from Python environment") +endif() + +# --- Common pip-CUDA setup --- +set(CMAKE_CUDA_COMPILER "${_PIP_CUDA_ROOT}/bin/nvcc" CACHE FILEPATH "CUDA compiler (from pip)" FORCE) +set(CUDAToolkit_ROOT "${_PIP_CUDA_ROOT}" CACHE PATH "CUDA toolkit root (from pip)" FORCE) + +list(APPEND CMAKE_LIBRARY_PATH "${_PIP_CUDA_ROOT}/lib/stubs" "${_PIP_CUDA_ROOT}/lib") + +message(STATUS "FindPipCUDAToolkit: using pip-installed CUDA toolkit") +message(STATUS " nvcc: ${CMAKE_CUDA_COMPILER}") +message(STATUS " root: ${CUDAToolkit_ROOT}") diff --git a/cmake/find_pip_cuda.py b/cmake/find_pip_cuda.py new file mode 100644 index 0000000..645c1aa --- /dev/null +++ b/cmake/find_pip_cuda.py @@ -0,0 +1,102 @@ +"""Locate pip-installed CUDA toolkit and prepare it for CMake consumption. + +Used by cmake/FindPipCUDAToolkit.cmake via ``execute_process``. +Outputs a JSON object with paths on success, exits with code 1 on failure. 
+ +Usage: + python find_pip_cuda.py # auto-detect from current env + python find_pip_cuda.py /path/to/cu13 # use explicit path, just prepare it +""" + +import json +import pathlib +import subprocess +import sys + + +def _find_cu_dir(): + """Find the nvidia/cu directory from the nvidia pip package.""" + try: + import nvidia + except ImportError: + return None + + nvidia_dir = pathlib.Path(nvidia.__path__[0]) + cu_dirs = sorted( + (d for d in nvidia_dir.iterdir() if d.name[:2] == "cu" and d.name[2:].isdigit()), + key=lambda d: int(d.name[2:]), + ) + if not cu_dirs: + return None + cu_dir = cu_dirs[-1] + if (cu_dir / "bin" / "nvcc").is_file(): + return cu_dir + return None + + +def _ensure_lib_symlinks(cu_dir): + """Create symlinks that CMake / nvcc expect but pip packages omit.""" + lib_dir = cu_dir / "lib" + if not lib_dir.is_dir(): + return + + # nvcc expects lib64/ on 64-bit + lib64 = cu_dir / "lib64" + if not lib64.exists(): + try: + lib64.symlink_to("lib") + except OSError: + pass + + # CMake expects unversioned .so (e.g., libcudart.so) + for so in lib_dir.glob("*.so.*"): + base = lib_dir / (so.name.split(".so.")[0] + ".so") + if not base.exists(): + try: + base.symlink_to(so.name) + except OSError: + pass + + +def _ensure_cuda_stub(cu_dir): + """Create a minimal libcuda.so stub for build-time -lcuda linking.""" + stubs_dir = cu_dir / "lib" / "stubs" + stub = stubs_dir / "libcuda.so" + if stub.exists(): + return + stubs_dir.mkdir(parents=True, exist_ok=True) + src = stubs_dir / "_stub.c" + try: + src.write_text("void cuGetErrorString(void){}\n") + subprocess.check_call( + ["gcc", "-shared", "-o", str(stub), str(src)], + stderr=subprocess.DEVNULL, + ) + except Exception: + pass + finally: + src.unlink(missing_ok=True) + + +def main(): + if len(sys.argv) > 1: + # Explicit path provided — just prepare it + cu_dir = pathlib.Path(sys.argv[1]) + else: + # Auto-detect from current Python environment + cu_dir = _find_cu_dir() + + if cu_dir is None or not (cu_dir / "bin" / "nvcc").is_file(): + sys.exit(1) + + _ensure_lib_symlinks(cu_dir) + _ensure_cuda_stub(cu_dir) + + print(json.dumps({ + "nvcc": str(cu_dir / "bin" / "nvcc"), + "root": str(cu_dir), + })) + + +if __name__ == "__main__": + main() diff --git a/conftest.py b/conftest.py index 8d6efbd..67d3096 100644 --- a/conftest.py +++ b/conftest.py @@ -103,7 +103,7 @@ def wrapped(*args, **kwargs): ret = org_func(*args, **kwargs) save_data(tmp_after_invoke_file, "hpc", func_name, ret, args, kwargs) - pypath = os.path.realpath(list(Path(__file__).parent.glob("./build/lib.*/"))[0]) + pypath = str(Path(__file__).parent) dump_test_py(tmp_py_file, tmp_before_invoke_file, tmp_after_invoke_file, pypath) print(tmp_py_file) @@ -123,7 +123,6 @@ def wrapped(*args, **kwargs): return True def hook(self): - sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("./build/lib.*/"))[0])) module = __import__(self.module_name) dirs = dir(module) diff --git a/hpc/__init__.py b/hpc/__init__.py index 7e2086f..58d76af 100644 --- a/hpc/__init__.py +++ b/hpc/__init__.py @@ -1,13 +1,32 @@ -import torch +from functools import lru_cache import importlib -import os import sys from pathlib import Path from types import ModuleType from typing import Dict +import torch +import tvm_ffi + + _pkg_dir = Path(__file__).parent +LIB_ROOT = _pkg_dir.parent / "build" +if not LIB_ROOT.exists(): + LIB_ROOT = _pkg_dir / "ops" + +# Define the torch library for op registration (needed for torch.compile tracing) +_torch_lib = torch.library.Library("hpc", "DEF") + + 
+@lru_cache(maxsize=None) +def load_ffi_lib(name: str): + """ + Libraries would be in `/build` or `/hpc/ops`. + """ + p = Path(name) + return tvm_ffi.load_module(LIB_ROOT / p.name) + def _discover_modules() -> Dict[str, ModuleType]: modules = {} @@ -40,16 +59,13 @@ def _export_functions(modules: Dict[str, ModuleType]): __all__.extend(funcs.keys()) -so_files = list(Path(__file__).parent.glob("_C.*.so")) -assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" -torch.ops.load_library(so_files[0]) - __all__ = [] _export_functions(_discover_modules()) -__version__ = torch.ops.hpc.version() -__built_json__ = torch.ops.hpc.built_json() +_lib = load_ffi_lib("_C.so") +__version__ = _lib.version() +__built_json__ = _lib.built_json() __doc__ = """ High Performance Computing Operators Library diff --git a/hpc/act.py b/hpc/act.py index f5b909c..a38eca3 100644 --- a/hpc/act.py +++ b/hpc/act.py @@ -3,6 +3,34 @@ import torch from torch import Tensor +from hpc import load_ffi_lib + +_lib = load_ffi_lib("_C.so") + +_torch_lib = torch.library.Library("hpc", "FRAGMENT") + +_torch_lib.define( + "act_mul_and_quant(Tensor input, Tensor scale, bool use_bf16_mul, Tensor? output) -> (Tensor)" +) +_torch_lib.impl("act_mul_and_quant", + lambda input, scale, use_bf16_mul, output: _lib.act_mul_and_quant(input, scale, use_bf16_mul, output), + "CUDA") + +_torch_lib.define( + "masked_act_mul_and_quant(Tensor input, Tensor scale, Tensor num_per_expert, Tensor? output) -> (Tensor)" +) +_torch_lib.impl("masked_act_mul_and_quant", + lambda input, scale, num_per_expert, output: _lib.masked_act_mul_and_quant(input, scale, num_per_expert, output), + "CUDA") + +_torch_lib.define( + "masked_act_mul_and_blockwise_quant(Tensor input, Tensor num_per_expert, Tensor? output, " + "Tensor? output_scale) -> (Tensor output, Tensor output_scale)" +) +_torch_lib.impl("masked_act_mul_and_blockwise_quant", + lambda input, num_per_expert, output, output_scale: _lib.masked_act_mul_and_blockwise_quant(input, num_per_expert, output, output_scale), + "CUDA") + def act_mul_and_quant( gate_up: Tensor, scale: Tensor, use_bf16_mul: bool = True, output: Tensor = None diff --git a/hpc/attention.py b/hpc/attention.py index 1587d29..98faeec 100644 --- a/hpc/attention.py +++ b/hpc/attention.py @@ -1,6 +1,60 @@ import torch from torch import Tensor +from hpc import load_ffi_lib + +_lib = load_ffi_lib("_C.so") + +_torch_lib = torch.library.Library("hpc", "FRAGMENT") + +_torch_lib.define( + "attention_prefill_bf16(Tensor q, Tensor k, Tensor v, Tensor seqlens_q, Tensor cu_seqlens_q, " + "int max_seqlens_q, Tensor? output) -> (Tensor)" +) +_torch_lib.impl("attention_prefill_bf16", + lambda q, k, v, seqlens_q, cu_seqlens_q, max_seqlens_q, output: + _lib.attention_prefill_bf16(q, k, v, seqlens_q, cu_seqlens_q, max_seqlens_q, output), + "CUDA") + +_torch_lib.define( + "attention_with_kvcache_prefill_bf16(Tensor q, Tensor kcache, Tensor vcache," + "Tensor cu_seqlens_q, " + "Tensor block_ids, Tensor num_seq_kvcache, int max_seqlens_q, Tensor? 
output) -> (Tensor)" +) +_torch_lib.impl("attention_with_kvcache_prefill_bf16", + lambda q, kcache, vcache, cu_seqlens_q, block_ids, num_seq_kvcache, max_seqlens_q, output: + _lib.attention_with_kvcache_prefill_bf16(q, kcache, vcache, cu_seqlens_q, block_ids, num_seq_kvcache, max_seqlens_q, output), + "CUDA") + +_torch_lib.define( + "attention_with_kvcache_prefill_fp8(Tensor q, Tensor kcache, Tensor vcache," + "Tensor qkscale, Tensor vscale, Tensor cu_seqlens_q," + "Tensor block_ids, Tensor num_seq_kvcache, int max_seqlens_q, Tensor? output) -> (Tensor)" +) +_torch_lib.impl("attention_with_kvcache_prefill_fp8", + lambda q, kcache, vcache, qkscale, vscale, cu_seqlens_q, block_ids, num_seq_kvcache, max_seqlens_q, output: + _lib.attention_with_kvcache_prefill_fp8(q, kcache, vcache, qkscale, vscale, cu_seqlens_q, block_ids, num_seq_kvcache, max_seqlens_q, output), + "CUDA") + +_torch_lib.define( + "attention_decode_bf16(Tensor q, Tensor! kcache, Tensor! vcache, Tensor block_ids, Tensor " + "num_seq_kvcache, bool new_kv_included, bool use_splitk, Tensor? output) -> (Tensor)" +) +_torch_lib.impl("attention_decode_bf16", + lambda q, kcache, vcache, block_ids, num_seq_kvcache, new_kv_included, use_splitk, output: + _lib.attention_decode_bf16(q, kcache, vcache, block_ids, num_seq_kvcache, new_kv_included, use_splitk, output), + "CUDA") + +_torch_lib.define( + "attention_decode_fp8(Tensor q, Tensor! kcache, Tensor! vcache, Tensor block_ids, Tensor " + "num_seq_kvcache, Tensor qscale, Tensor kscale, Tensor vscale, bool new_kv_included, bool " + "use_splitk, Tensor? split_flag, Tensor? output) -> (Tensor)" +) +_torch_lib.impl("attention_decode_fp8", + lambda q, kcache, vcache, block_ids, num_seq_kvcache, qscale, kscale, vscale, new_kv_included, use_splitk, split_flag, output: + _lib.attention_decode_fp8(q, kcache, vcache, block_ids, num_seq_kvcache, qscale, kscale, vscale, new_kv_included, use_splitk, split_flag, output), + "CUDA") + def attention_prefill_bf16( q: Tensor, diff --git a/hpc/fuse_moe.py b/hpc/fuse_moe.py index e698d21..4e8143f 100644 --- a/hpc/fuse_moe.py +++ b/hpc/fuse_moe.py @@ -2,6 +2,55 @@ from torch import Tensor from typing import Tuple +from hpc import load_ffi_lib + +_lib = load_ffi_lib("_C.so") + +_torch_lib = torch.library.Library("hpc", "FRAGMENT") + +_torch_lib.define( + "count_and_gather(Tensor x, Tensor topk_ids, int num_expert, int rank_ep, int " + "intermediate_size, int num_seq_per_group_avg" + ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)" +) +_torch_lib.impl("count_and_gather", + lambda x, topk_ids, num_expert, rank_ep, intermediate_size, num_seq_per_group_avg: + _lib.count_and_gather(x, topk_ids, num_expert, rank_ep, intermediate_size, num_seq_per_group_avg), + "CUDA") + +_torch_lib.define( + "reduce(Tensor x, Tensor topk_pos, Tensor topk_scale, Tensor? shared_output" + ") -> (Tensor)" +) +_torch_lib.impl("reduce", + lambda x, topk_pos, topk_scale, shared_output: + _lib.reduce(x, topk_pos, topk_scale, shared_output), + "CUDA") + +_torch_lib.define( + "fuse_moe_pertensor_fp8(Tensor x, Tensor gate_up_weight, Tensor down_weight, Tensor " + "gate_up_scale, " + "Tensor down_scale, Tensor act_and_mul_scale, Tensor topk_ids, Tensor topk_scale, Tensor? 
" + "shared_output, " + "int rank_ep, int num_expert_total, bool use_bf16_mul) -> (Tensor)" +) +_torch_lib.impl("fuse_moe_pertensor_fp8", + lambda x, gate_up_weight, down_weight, gate_up_scale, down_scale, act_and_mul_scale, topk_ids, topk_scale, shared_output, rank_ep, num_expert_total, use_bf16_mul: + _lib.fuse_moe_pertensor_fp8(x, gate_up_weight, down_weight, gate_up_scale, down_scale, act_and_mul_scale, topk_ids, topk_scale, shared_output, rank_ep, num_expert_total, use_bf16_mul), + "CUDA") + +_torch_lib.define( + "fuse_moe_blockwise_fp8(Tensor x, Tensor x_scale, Tensor gate_up_weight, Tensor " + "gate_up_weight_scale, " + "Tensor down_weight, Tensor down_weight_scale, Tensor topk_ids, Tensor topk_scale, Tensor? " + "shared_output, " + "int rank_ep, int num_expert_total) -> (Tensor)" +) +_torch_lib.impl("fuse_moe_blockwise_fp8", + lambda x, x_scale, gate_up_weight, gate_up_weight_scale, down_weight, down_weight_scale, topk_ids, topk_scale, shared_output, rank_ep, num_expert_total: + _lib.fuse_moe_blockwise_fp8(x, x_scale, gate_up_weight, gate_up_weight_scale, down_weight, down_weight_scale, topk_ids, topk_scale, shared_output, rank_ep, num_expert_total), + "CUDA") + def count_and_gather( x: Tensor, @@ -11,14 +60,14 @@ def count_and_gather( intermediate_size: int, num_seq_per_group_avg: int, ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, ]: """Sorts and aggregates token based on expert assignments for MoE layers. @@ -78,7 +127,7 @@ def count_and_gather( - The function modifies the output buffers in-place when provided - Expert assignments in topk_ids should be in range [0, num_expert-1] """ - return torch.ops.hpc.count_and_gather(x, topk_ids, num_expert, rank_ep, intermediate_size) + return torch.ops.hpc.count_and_gather(x, topk_ids, num_expert, rank_ep, intermediate_size, num_seq_per_group_avg) def reduce( @@ -306,11 +355,12 @@ def count_and_gather_fake( torch.empty((num_expert), dtype=torch.int32), torch.empty((num_expert + 1), dtype=torch.int32), torch.empty((num_expert * 2 * 128), dtype=torch.int8), + torch.empty((num_expert * 2 * 128), dtype=torch.int8), # dowm_tmas ) @torch.library.register_fake("hpc::reduce") -def reduce_fake(x, topk_pos, topk_scale): +def reduce_fake(x, topk_pos, topk_scale, shared_output=None): return torch.empty((topk_pos.shape[0], x.shape[1]), dtype=torch.bfloat16) @@ -324,6 +374,7 @@ def fuse_moe_pertensor_fp8_fake( act_and_mul_scale, topk_ids, topk_scale, + shared_output, rank_ep, num_expert_total, use_bf16_mul, @@ -341,8 +392,8 @@ def fuse_moe_blockwise_fp8_fake( down_weight_scale: Tensor, topk_ids: Tensor, topk_scale: Tensor, + shared_output, rank_ep: int, num_expert_total: int, - use_bf16_mul: bool = True, ): - return torch.empty((x.shape[0], x.shape[1]), dtype=torch.bfloat16) + return torch.empty((x.shape[0], x.shape[1]), dtype=torch.bfloat16) \ No newline at end of file diff --git a/hpc/group_gemm.py b/hpc/group_gemm.py index 0631c2e..56c3891 100644 --- a/hpc/group_gemm.py +++ b/hpc/group_gemm.py @@ -2,6 +2,41 @@ from torch import Tensor from typing import Tuple, Optional +from hpc import load_ffi_lib + +_lib = load_ffi_lib("_C.so") + +_torch_lib = torch.library.Library("hpc", "FRAGMENT") + +_torch_lib.define( + "reformat_x_scale(Tensor x_scale, Tensor seqlens, Tensor cu_seqlens, " + "Tensor? 
out_x_scale, int num_seq_per_group_avg) -> (Tensor)" +) +_torch_lib.impl("reformat_x_scale", + lambda x_scale, seqlens, cu_seqlens, out_x_scale, num_seq_per_group_avg: + _lib.reformat_x_scale(x_scale, seqlens, cu_seqlens, out_x_scale, num_seq_per_group_avg), + "CUDA") + +_torch_lib.define( + "group_gemm_pertensor_fp8(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, Tensor " + "y_scale, " + "int num_seq_per_group_avg, Tensor? output, Tensor? tma_desc) -> (Tensor)" +) +_torch_lib.impl("group_gemm_pertensor_fp8", + lambda x, weight, seqlens, cu_seqlens, y_scale, num_seq_per_group_avg, output, tma_desc: + _lib.group_gemm_pertensor_fp8(x, weight, seqlens, cu_seqlens, y_scale, num_seq_per_group_avg, output, tma_desc), + "CUDA") + +_torch_lib.define( + "group_gemm_blockwise_fp8(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, Tensor " + "xscale, Tensor wscale," + "int num_seq_per_group_avg, Tensor? output, Tensor? tma_desc) -> (Tensor)" +) +_torch_lib.impl("group_gemm_blockwise_fp8", + lambda x, weight, seqlens, cu_seqlens, xscale, wscale, num_seq_per_group_avg, output, tma_desc: + _lib.group_gemm_blockwise_fp8(x, weight, seqlens, cu_seqlens, xscale, wscale, num_seq_per_group_avg, output, tma_desc), + "CUDA") + def reformat_x_scale( x_scale: Tensor, diff --git a/hpc/ops/.gitkeep b/hpc/ops/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7439b20 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,59 @@ +[project] +name = "hpc-ops" +dynamic = ["version"] +description = "High Performance Computing Operator" +dependencies = ["apache-tvm-ffi", "torch>=2.4.0"] + +[project.optional-dependencies] +cuda = ["nvidia-cuda-nvcc", "nvidia-cuda-cccl"] + +[build-system] +requires = ["scikit-build-core", "apache-tvm-ffi>=0.1.8.post2"] +build-backend = "scikit_build_core.build" + +[tool.scikit-build] +wheel.packages = ["hpc"] + +metadata.version.provider = "version_with_meta" +metadata.version.provider-path = "." +experimental = true +build-dir = "build" + +wheel.py-api = "cp38" +cmake.version = ">=3.27" + +wheel.exclude = ["*.cu", "*.h", "*.cuh", "*.cc"] + +[tool.cibuildwheel] +archs = ["auto64"] +manylinux-x86_64-image = "manylinux_2_28" +manylinux-aarch64-image = "manylinux_2_28" +skip = "*musllinux*" +build = "cp38-*" +environment-pass = ["CUDA_VERSION"] + +[tool.cibuildwheel.linux] +repair-wheel-command = [ + "auditwheel repair --exclude libtvm_ffi.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}", + "pipx run abi3audit --strict --report {wheel}", +] +environment.PATH = "/usr/local/cuda/bin:$PATH" +before-all = """ +set -eux + +case "$(uname -m)" in + "x86_64") + dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo + ;; + "aarch64") + dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo + ;; + *) + exit 1 + ;; +esac + +cudaver="$(echo "${CUDA_VERSION:-"12.8"}" | cut -d '.' 
-f-2)" +v="${cudaver//./-}" +yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}" nvidia-driver-cuda-libs +""" diff --git a/requirements-dev.txt b/requirements-dev.txt index edf7020..8d40b97 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,7 @@ numpy==2.3.1 -torch==2.7.0 +torch>=2.4.0 +apache-tvm-ffi>=0.1.8.post2 +scikit-build-core pytest==8.4.1 setuptools==68.0.0 wheel==0.45.1 diff --git a/setup.py b/setup.py deleted file mode 100644 index 511c16a..0000000 --- a/setup.py +++ /dev/null @@ -1,73 +0,0 @@ -import os -import shutil -import subprocess -import sys - -from setuptools import Extension, setup -from setuptools.command.build_ext import build_ext - - -class CMakeExtension(Extension): - def __init__(self, name, version_macros=[], sourcedir=""): - Extension.__init__(self, name, sources=[]) - self.version_macros = version_macros - self.sourcedir = os.path.abspath(sourcedir) - - -class CMakeBuild(build_ext): - def run(self): - for ext in self.extensions: - self.build_extension(ext) - - def build_extension(self, ext): - build_lib_dir = os.path.dirname(self.get_ext_fullpath(ext.name)) - build_temp_dir = os.path.join(self.build_temp, ext.name) - - os.makedirs(build_lib_dir, exist_ok=True) - os.makedirs(build_temp_dir, exist_ok=True) - - cmake_args = [f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={build_lib_dir}", *ext.version_macros] - - subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp_dir) - subprocess.check_call( - ["cmake", "--build", ".", "--config", "Release", "-j16"], cwd=build_temp_dir - ) - - so_src_path = os.path.join(build_temp_dir, "_C.abi3.so") - so_dst_path = os.path.join(build_lib_dir, "hpc/_C.abi3.so") - shutil.copy(so_src_path, so_dst_path) - - -def get_version(): - git_hash = subprocess.check_output( - ["git", "rev-parse", "--short=7", "HEAD"], stderr=subprocess.DEVNULL, text=True - ).strip() - - return f"0.0.1.dev0+g{git_hash}", git_hash - - -version, git_hash = get_version() -version_macros = [ - '-DHPC_VERSION_STR="{}"'.format(version), - '-DHPC_GIT_HASH_STR="{}"'.format(git_hash), -] - -with open("hpc/version.py", "w") as fp: - fp.write('version = "{}"\n'.format(version)) - fp.write('git_hash = "{}"\n'.format(git_hash)) - -setup( - name="hpc-ops", - version=version, - description="High Performance Computing Operator", - author="Tencent hpc-ops authors", - author_email="authors@hpc-ops", - url="https://github.com/Tencent/hpc-ops", - license="Copyright (C) 2026 Tencent.", - packages=["hpc"], - ext_modules=[CMakeExtension("hpc", version_macros)], - cmdclass={"build_ext": CMakeBuild}, - package_data={"_C": ["*.so"]}, - options={"bdist_wheel": {"py_limited_api": "cp39"}}, - install_requires=["torch"], -) diff --git a/src/C/C.cc b/src/C/C.cc index 1f7b0d3..0df394d 100644 --- a/src/C/C.cc +++ b/src/C/C.cc @@ -1,5 +1,4 @@ // Copyright (C) 2026 Tencent. - -#include - -TORCH_LIBRARY(hpc, m) {} +// This file is intentionally left as a placeholder. +// With tvm-ffi, each module registers its own exported functions +// via TVM_FFI_DLL_EXPORT_TYPED_FUNC macros in their respective entry files. diff --git a/src/C/built_json.cu b/src/C/built_json.cu index 2a91b69..d933138 100644 --- a/src/C/built_json.cu +++ b/src/C/built_json.cu @@ -1,7 +1,7 @@ // Copyright (C) 2026 Tencent. 
#include -#include +#include #include #include @@ -15,7 +15,7 @@ #define HPC_GIT_HASH_STR "unknown" #endif -static const std::string built_json() { +static std::string built_json() { std::ostringstream oss; // NOLINTBEGIN clang-format off @@ -42,4 +42,4 @@ static const std::string built_json() { return oss.str(); } -TORCH_LIBRARY_FRAGMENT(hpc, m) { m.def("built_json", &built_json); } +TVM_FFI_DLL_EXPORT_TYPED_FUNC(built_json, built_json); diff --git a/src/C/version.cc b/src/C/version.cc index 6bbf876..93a6c7e 100644 --- a/src/C/version.cc +++ b/src/C/version.cc @@ -1,14 +1,13 @@ // Copyright (C) 2026 Tencent. -#include +#include -#include #include #ifndef HPC_VERSION_STR #define HPC_VERSION_STR "unknown" #endif -static const std::string version() { return HPC_VERSION_STR; } +static std::string version() { return HPC_VERSION_STR; } -TORCH_LIBRARY_FRAGMENT(hpc, m) { m.def("version", &version); } +TVM_FFI_DLL_EXPORT_TYPED_FUNC(version, version); diff --git a/src/activation/entry.cc b/src/activation/entry.cc index 481fdac..bb9364f 100644 --- a/src/activation/entry.cc +++ b/src/activation/entry.cc @@ -1,46 +1,55 @@ // Copyright (C) 2026 Tencent. -#include #include -#include -#include +#include +#include +#include +#include +#include + +#include #include #include +#include "tvm_ffi_utils.h" + #include "src/activation/activation.h" namespace hpc { namespace activation { -torch::Tensor act_mul_and_quant_entry(const torch::Tensor &input, const torch::Tensor &scale, - bool use_bf16_mul, std::optional output) { - auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); +tvm::ffi::Tensor act_mul_and_quant_entry(const tvm::ffi::TensorView &input, + const tvm::ffi::TensorView &scale, bool use_bf16_mul, + tvm::ffi::Optional output) { + auto stream = TVM_FFI_GET_CUDA_STREAM(input); + auto device = input.device(); - std::vector output_shape(input.sizes().begin(), input.sizes().end()); - output_shape[output_shape.size() - 1] /= 2; - - auto options = input.options().dtype(torch::kFloat8_e4m3fn); + int ndim = input.ndim(); + std::vector output_shape; + for (int i = 0; i < ndim; ++i) { + output_shape.push_back(input.shape().at(i)); + } + output_shape[ndim - 1] /= 2; - torch::Tensor output_tensor; + tvm::ffi::Tensor output_tensor; if (output.has_value()) { - output_tensor = output.value(); + output_tensor = tvm::ffi::Tensor(output.value()); } else { - output_tensor = torch::empty(output_shape, options); + output_tensor = tvm_ffi_empty(output_shape, dl_float8_e4m3, device); } using Tin = __nv_bfloat16; using Tout = __nv_fp8_e4m3; - const auto *input_ptr = reinterpret_cast(input.const_data_ptr()); - auto *output_ptr = reinterpret_cast(output_tensor.mutable_data_ptr()); - const float *scale_ptr = scale.const_data_ptr(); + const auto *input_ptr = reinterpret_cast(input.data_ptr()); + auto *output_ptr = reinterpret_cast(output_tensor.data_ptr()); + const float *scale_ptr = reinterpret_cast(scale.data_ptr()); - auto input_shape = input.sizes(); - int num_col = input_shape[input_shape.size() - 1]; + int num_col = input.shape().at(ndim - 1); int num_row = 1; - for (uint32_t i = 0; i < input_shape.size() - 1; ++i) { - num_row *= input_shape[i]; + for (int i = 0; i < ndim - 1; ++i) { + num_row *= input.shape().at(i); } act_mul_and_quant_async(output_ptr, input_ptr, scale_ptr, num_row, num_col, use_bf16_mul, stream); @@ -48,48 +57,53 @@ torch::Tensor act_mul_and_quant_entry(const torch::Tensor &input, const torch::T return output_tensor; } -torch::Tensor masked_act_mul_and_quant_entry(const torch::Tensor &input, 
torch::Tensor &scale, - const torch::Tensor &num_per_expert, - std::optional output) { - auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); - - TORCH_CHECK(input.is_contiguous(), "input tensor must be contiguous"); - TORCH_CHECK(scale.is_contiguous(), "scale tensor must be contiguous"); - TORCH_CHECK(num_per_expert.is_contiguous(), "num_per_expert tensor must be contiguous"); - - TORCH_CHECK(input.device().is_cuda(), "input tensor's device must be cuda"); - TORCH_CHECK(scale.device().is_cuda(), "scale tensor's device must be cuda"); - TORCH_CHECK(num_per_expert.device().is_cuda(), "num_per_expert tensor's device must be cuda"); - - TORCH_CHECK(input.size(-1) / 2 % 8 == 0, "hidden dim must be divided by 8") - TORCH_CHECK(scale.numel() == 1, "only support per tensor qunat") - - std::vector output_shape(input.sizes().begin(), input.sizes().end()); - output_shape[output_shape.size() - 1] /= 2; - - auto options = input.options().dtype(torch::kFloat8_e4m3fn); +tvm::ffi::Tensor masked_act_mul_and_quant_entry(const tvm::ffi::TensorView &input, + const tvm::ffi::TensorView &scale, + const tvm::ffi::TensorView &num_per_expert, + tvm::ffi::Optional output) { + auto stream = TVM_FFI_GET_CUDA_STREAM(input); + auto device = input.device(); + + TVM_FFI_CHECK_CONTIGUOUS(input); + TVM_FFI_CHECK_CONTIGUOUS(scale); + TVM_FFI_CHECK_CONTIGUOUS(num_per_expert); + + TVM_FFI_CHECK_CUDA(input); + TVM_FFI_CHECK_CUDA(scale); + TVM_FFI_CHECK_CUDA(num_per_expert); + + TVM_FFI_ICHECK(input.shape().at(input.ndim() - 1) / 2 % 8 == 0) + << "hidden dim must be divided by 8"; + TVM_FFI_ICHECK(num_per_expert.numel() > 0) << "num_per_expert must not be empty"; + + int ndim = input.ndim(); + std::vector output_shape; + for (int i = 0; i < ndim; ++i) { + output_shape.push_back(input.shape().at(i)); + } + output_shape[ndim - 1] /= 2; - torch::Tensor output_tensor; + tvm::ffi::Tensor output_tensor; if (output.has_value()) { - output_tensor = output.value(); + output_tensor = tvm::ffi::Tensor(output.value()); } else { - output_tensor = torch::empty(output_shape, options); + output_tensor = tvm_ffi_empty(output_shape, dl_float8_e4m3, device); } using Tin = __nv_bfloat16; using Tout = __nv_fp8_e4m3; - const auto *input_ptr = reinterpret_cast(input.const_data_ptr()); - const auto *scale_ptr = reinterpret_cast(scale.const_data_ptr()); - auto *output_ptr = reinterpret_cast(output_tensor.mutable_data_ptr()); + const auto *input_ptr = reinterpret_cast(input.data_ptr()); + const auto *scale_ptr = reinterpret_cast(scale.data_ptr()); + auto *output_ptr = reinterpret_cast(output_tensor.data_ptr()); - const auto *num_per_expert_ptr = num_per_expert.const_data_ptr(); + const auto *num_per_expert_ptr = reinterpret_cast(num_per_expert.data_ptr()); - int num_experts = num_per_expert.size(0); - int num_total_tokens = input.size(0); + int num_experts = num_per_expert.shape().at(0); + int num_total_tokens = input.shape().at(0); int num_tokens_per_expert = num_total_tokens / num_experts; - int num_intermediate_size = input.size(1) / 2; + int num_intermediate_size = input.shape().at(1) / 2; masked_act_mul_and_quant_async(output_ptr, input_ptr, scale_ptr, num_per_expert_ptr, num_total_tokens, num_intermediate_size, num_tokens_per_expert, @@ -98,79 +112,73 @@ torch::Tensor masked_act_mul_and_quant_entry(const torch::Tensor &input, torch:: return output_tensor; } -std::tuple masked_act_mul_and_blockwise_quant_entry( - const torch::Tensor &input, const torch::Tensor &num_per_expert, - std::optional output, std::optional output_scale) { - 
auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); +tvm::ffi::Tuple masked_act_mul_and_blockwise_quant_entry( + const tvm::ffi::TensorView &input, const tvm::ffi::TensorView &num_per_expert, + tvm::ffi::Optional output, + tvm::ffi::Optional output_scale) { + auto stream = TVM_FFI_GET_CUDA_STREAM(input); + auto device = input.device(); - TORCH_CHECK(input.is_contiguous(), "input tensor must be contiguous"); - TORCH_CHECK(num_per_expert.is_contiguous(), "num_per_expert tensor must be contiguous"); + TVM_FFI_CHECK_CONTIGUOUS(input); + TVM_FFI_CHECK_CONTIGUOUS(num_per_expert); - TORCH_CHECK(input.device().is_cuda(), "input tensor's device must be cuda"); - TORCH_CHECK(num_per_expert.device().is_cuda(), "num_per_expert tensor's device must be cuda"); + TVM_FFI_CHECK_CUDA(input); + TVM_FFI_CHECK_CUDA(num_per_expert); - TORCH_CHECK(input.size(-1) / 2 % 128 == 0, "hidden dim must be divided by 128") + TVM_FFI_ICHECK(input.shape().at(input.ndim() - 1) / 2 % 128 == 0) + << "hidden dim must be divided by 128"; - std::vector output_shape(input.sizes().begin(), input.sizes().end()); - output_shape[output_shape.size() - 1] /= 2; + int ndim = input.ndim(); + std::vector output_shape; + for (int i = 0; i < ndim; ++i) { + output_shape.push_back(input.shape().at(i)); + } + output_shape[ndim - 1] /= 2; auto output_scale_shape = output_shape; output_scale_shape[output_scale_shape.size() - 1] /= 128; - auto options = input.options(); - - torch::Tensor output_tensor; + tvm::ffi::Tensor output_tensor; if (output.has_value()) { - output_tensor = output.value(); + output_tensor = tvm::ffi::Tensor(output.value()); } else { - output_tensor = torch::empty(output_shape, options.dtype(torch::kFloat8_e4m3fn)); + output_tensor = tvm_ffi_empty(output_shape, dl_float8_e4m3, device); } - torch::Tensor output_scale_tensor; + tvm::ffi::Tensor output_scale_tensor; if (output_scale.has_value()) { - output_scale_tensor = output_scale.value(); + output_scale_tensor = tvm::ffi::Tensor(output_scale.value()); } else { - output_scale_tensor = torch::empty(output_scale_shape, options.dtype(torch::kFloat32)); + output_scale_tensor = tvm_ffi_empty(output_scale_shape, dl_float32, device); } using Tin = __nv_bfloat16; using Tout = __nv_fp8_e4m3; - const auto *input_ptr = reinterpret_cast(input.const_data_ptr()); - auto *output_ptr = reinterpret_cast(output_tensor.mutable_data_ptr()); - auto *output_scale_ptr = reinterpret_cast(output_scale_tensor.mutable_data_ptr()); + const auto *input_ptr = reinterpret_cast(input.data_ptr()); + auto *output_ptr = reinterpret_cast(output_tensor.data_ptr()); + auto *output_scale_ptr = reinterpret_cast(output_scale_tensor.data_ptr()); - const auto *num_per_expert_ptr = num_per_expert.const_data_ptr(); + const auto *num_per_expert_ptr = reinterpret_cast(num_per_expert.data_ptr()); - int num_experts = num_per_expert.size(0); - int num_total_tokens = input.size(0); + int num_experts = num_per_expert.shape().at(0); + int num_total_tokens = input.shape().at(0); int num_tokens_per_expert = num_total_tokens / num_experts; - int num_intermediate_size = input.size(1) / 2; + int num_intermediate_size = input.shape().at(1) / 2; masked_act_mul_and_blockwise_quant_async(output_ptr, output_scale_ptr, input_ptr, num_per_expert_ptr, num_total_tokens, num_intermediate_size, num_tokens_per_expert, stream); - return std::make_tuple(output_tensor, output_scale_tensor); + return tvm::ffi::Tuple(output_tensor, output_scale_tensor); } } // namespace activation } // namespace hpc -TORCH_LIBRARY_FRAGMENT(hpc, m) { - 
m.def( - "act_mul_and_quant(Tensor input, Tensor scale, bool use_bf16_mul, Tensor? output) -> " - "(Tensor)"); - m.impl("act_mul_and_quant", torch::kCUDA, &hpc::activation::act_mul_and_quant_entry); - m.def( - "masked_act_mul_and_quant(Tensor input, Tensor scale, Tensor num_per_expert, Tensor? output) " - "-> (Tensor)"); - m.impl("masked_act_mul_and_quant", torch::kCUDA, - &hpc::activation::masked_act_mul_and_quant_entry); - m.def( - "masked_act_mul_and_blockwise_quant(Tensor input, Tensor num_per_expert, Tensor? output, " - "Tensor? output_scale) -> (Tensor output, " - "Tensor output_scale)"); - m.impl("masked_act_mul_and_blockwise_quant", torch::kCUDA, - &hpc::activation::masked_act_mul_and_blockwise_quant_entry); -} +TVM_FFI_DLL_EXPORT_TYPED_FUNC(act_mul_and_quant, + hpc::activation::act_mul_and_quant_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(masked_act_mul_and_quant, + hpc::activation::masked_act_mul_and_quant_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(masked_act_mul_and_blockwise_quant, + hpc::activation::masked_act_mul_and_blockwise_quant_entry); diff --git a/src/attention/entry.cc b/src/attention/entry.cc index d6cff7c..e5d12e2 100644 --- a/src/attention/entry.cc +++ b/src/attention/entry.cc @@ -1,9 +1,13 @@ // Copyright (C) 2026 Tencent. -#include #include -#include -#include +#include + +#include +#include +#include + +#include "tvm_ffi_utils.h" #include "src/attention/decode/decode.h" #include "src/attention/prefill/prefill.h" @@ -11,46 +15,49 @@ namespace hpc { namespace attention { -torch::Tensor attention_prefill_bf16_entry(const torch::Tensor &q, const torch::Tensor &k, - const torch::Tensor &v, const torch::Tensor &seqlens_q, - const torch::Tensor &cu_seqlens_q, int64_t max_seqlens_q, - std::optional output) { - auto stream = at::cuda::getCurrentCUDAStream(q.get_device()); - TORCH_CHECK(q.device().is_cuda(), "q tensor must be cuda"); - TORCH_CHECK(k.device().is_cuda(), "k tensor must be cuda"); - TORCH_CHECK(v.device().is_cuda(), "v tensor must be cuda"); - TORCH_CHECK(seqlens_q.device().is_cuda(), "seqlens_q tensor must be cuda"); - TORCH_CHECK(cu_seqlens_q.device().is_cuda(), "cu_seqlens_q tensor must be cuda"); - - int total_seq_q = q.size(0); - int num_head_q = q.size(1); - int num_dim_qk = q.size(2); - - int num_head_kv = v.size(1); - int num_dim_v = v.size(2); - - int num_batch = seqlens_q.size(0); - - auto options = q.options().dtype(torch::kBFloat16); - torch::Tensor y; +tvm::ffi::Tensor attention_prefill_bf16_entry(const tvm::ffi::TensorView &q, + const tvm::ffi::TensorView &k, + const tvm::ffi::TensorView &v, + const tvm::ffi::TensorView &seqlens_q, + const tvm::ffi::TensorView &cu_seqlens_q, + int64_t max_seqlens_q, + tvm::ffi::Optional output) { + auto stream = TVM_FFI_GET_CUDA_STREAM(q); + TVM_FFI_CHECK_CUDA(q); + TVM_FFI_CHECK_CUDA(k); + TVM_FFI_CHECK_CUDA(v); + TVM_FFI_CHECK_CUDA(seqlens_q); + TVM_FFI_CHECK_CUDA(cu_seqlens_q); + + int total_seq_q = q.shape().at(0); + int num_head_q = q.shape().at(1); + int num_dim_qk = q.shape().at(2); + + int num_head_kv = v.shape().at(1); + int num_dim_v = v.shape().at(2); + + int num_batch = seqlens_q.shape().at(0); + + auto device = q.device(); + tvm::ffi::Tensor y; if (output.has_value()) { - y = output.value(); + y = tvm::ffi::Tensor(output.value()); } else { - y = torch::empty({total_seq_q, num_head_q, num_dim_v}, options); + y = tvm_ffi_empty({total_seq_q, num_head_q, num_dim_v}, dl_bfloat16, device); } int num_tmas = 4 * num_batch; - torch::Tensor tmas = torch::empty({num_tmas, 64}, options); + tvm::ffi::Tensor tmas = 
tvm_ffi_empty({num_tmas, 64}, dl_bfloat16, device); - const auto *q_ptr = q.const_data_ptr(); - const auto *k_ptr = k.const_data_ptr(); - const auto *v_ptr = v.const_data_ptr(); - const auto *seqlens_q_ptr = seqlens_q.const_data_ptr(); - const auto *cu_seqlens_q_ptr = cu_seqlens_q.const_data_ptr(); - void *tmas_ptr = tmas.mutable_data_ptr(); + const auto *q_ptr = q.data_ptr(); + const auto *k_ptr = k.data_ptr(); + const auto *v_ptr = v.data_ptr(); + const auto *seqlens_q_ptr = seqlens_q.data_ptr(); + const auto *cu_seqlens_q_ptr = cu_seqlens_q.data_ptr(); + void *tmas_ptr = tmas.data_ptr(); using T = __nv_bfloat16; - auto *y_ptr = reinterpret_cast(y.mutable_data_ptr()); + auto *y_ptr = reinterpret_cast(y.data_ptr()); int ldQ = q.stride(0); // num_head_q * num_dim_qk; int ldK = k.stride(0); // num_head_kv * num_dim_qk; @@ -64,54 +71,54 @@ torch::Tensor attention_prefill_bf16_entry(const torch::Tensor &q, const torch:: return y; } -torch::Tensor attention_with_kvcache_prefill_bf16_entry( - const torch::Tensor &q, const torch::Tensor &kcache, const torch::Tensor &vcache, - const torch::Tensor &cu_seqlens_q, const torch::Tensor block_ids, - const torch::Tensor seqlens_kvcache, int64_t max_seqlens_q, - std::optional output) { - auto stream = at::cuda::getCurrentCUDAStream(q.get_device()); - TORCH_CHECK(q.device().is_cuda(), "q tensor must be cuda"); - TORCH_CHECK(kcache.device().is_cuda(), "kcache tensor must be cuda"); - TORCH_CHECK(vcache.device().is_cuda(), "vcache tensor must be cuda"); - TORCH_CHECK(cu_seqlens_q.device().is_cuda(), "cu_seqlens_q tensor must be cuda"); - TORCH_CHECK(block_ids.device().is_cuda(), "block_ids tensor must be cuda"); - TORCH_CHECK(seqlens_kvcache.device().is_cuda(), "seqlens_kvcache tensor must be cuda"); +tvm::ffi::Tensor attention_with_kvcache_prefill_bf16_entry( + const tvm::ffi::TensorView &q, const tvm::ffi::TensorView &kcache, + const tvm::ffi::TensorView &vcache, const tvm::ffi::TensorView &cu_seqlens_q, + const tvm::ffi::TensorView &block_ids, const tvm::ffi::TensorView &seqlens_kvcache, + int64_t max_seqlens_q, tvm::ffi::Optional output) { + auto stream = TVM_FFI_GET_CUDA_STREAM(q); + TVM_FFI_CHECK_CUDA(q); + TVM_FFI_CHECK_CUDA(kcache); + TVM_FFI_CHECK_CUDA(vcache); + TVM_FFI_CHECK_CUDA(cu_seqlens_q); + TVM_FFI_CHECK_CUDA(block_ids); + TVM_FFI_CHECK_CUDA(seqlens_kvcache); - int total_seq_q = q.size(0); - int num_head_q = q.size(1); - int num_dim_qk = q.size(2); + int total_seq_q = q.shape().at(0); + int num_head_q = q.shape().at(1); + int num_dim_qk = q.shape().at(2); - int num_batch = cu_seqlens_q.size(0) - 1; + int num_batch = cu_seqlens_q.shape().at(0) - 1; - int num_kvcache_blocks = kcache.size(0); - int block_size = kcache.size(1); + int num_kvcache_blocks = kcache.shape().at(0); + int block_size = kcache.shape().at(1); - int num_head_kv = kcache.size(2); - int num_dim_v = vcache.size(3); + int num_head_kv = kcache.shape().at(2); + int num_dim_v = vcache.shape().at(3); - int num_seq_max_blocks = block_ids.size(1); + int num_seq_max_blocks = block_ids.shape().at(1); - auto options = q.options().dtype(torch::kBFloat16); - torch::Tensor y; + auto device = q.device(); + tvm::ffi::Tensor y; if (output.has_value()) { - y = output.value(); + y = tvm::ffi::Tensor(output.value()); } else { - y = torch::empty({total_seq_q, num_head_q, num_dim_v}, options); + y = tvm_ffi_empty({total_seq_q, num_head_q, num_dim_v}, dl_bfloat16, device); } int num_tmas = 2 * num_batch; - torch::Tensor tmas = torch::empty({num_tmas, 64}, options); + tvm::ffi::Tensor tmas = 
tvm_ffi_empty({num_tmas, 64}, dl_bfloat16, device); - const auto *q_ptr = q.const_data_ptr(); - const auto *kcache_ptr = kcache.const_data_ptr(); - const auto *vcache_ptr = vcache.const_data_ptr(); - const auto *cu_seqlens_q_ptr = cu_seqlens_q.const_data_ptr(); - const auto *block_ids_ptr = block_ids.const_data_ptr(); - const auto *seqlens_kvcache_ptr = seqlens_kvcache.const_data_ptr(); - void *tmas_ptr = tmas.mutable_data_ptr(); + const auto *q_ptr = q.data_ptr(); + const auto *kcache_ptr = kcache.data_ptr(); + const auto *vcache_ptr = vcache.data_ptr(); + const auto *cu_seqlens_q_ptr = cu_seqlens_q.data_ptr(); + const auto *block_ids_ptr = block_ids.data_ptr(); + const auto *seqlens_kvcache_ptr = seqlens_kvcache.data_ptr(); + void *tmas_ptr = tmas.data_ptr(); using T = __nv_bfloat16; - auto *y_ptr = reinterpret_cast(y.mutable_data_ptr()); + auto *y_ptr = reinterpret_cast(y.data_ptr()); int ldQ = q.stride(0); // num_head_q * num_dim_qk; int ldK = kcache.stride(0); // num_head_kv * num_dim_qk; @@ -126,60 +133,61 @@ torch::Tensor attention_with_kvcache_prefill_bf16_entry( return y; } -torch::Tensor attention_with_kvcache_prefill_fp8_entry( - const torch::Tensor &q, const torch::Tensor &kcache, const torch::Tensor &vcache, - const torch::Tensor &qkscale, const torch::Tensor &vscale, const torch::Tensor &cu_seqlens_q, - const torch::Tensor block_ids, const torch::Tensor seqlens_kvcache, int64_t max_seqlens_q, - std::optional output) { - auto stream = at::cuda::getCurrentCUDAStream(q.get_device()); - TORCH_CHECK(q.device().is_cuda(), "q tensor must be cuda"); - TORCH_CHECK(kcache.device().is_cuda(), "kcache tensor must be cuda"); - TORCH_CHECK(vcache.device().is_cuda(), "vcache tensor must be cuda"); - TORCH_CHECK(qkscale.device().is_cuda(), "qkscale tensor must be cuda"); - TORCH_CHECK(vscale.device().is_cuda(), "vscale tensor must be cuda"); - TORCH_CHECK(cu_seqlens_q.device().is_cuda(), "cu_seqlens_q tensor must be cuda"); - TORCH_CHECK(block_ids.device().is_cuda(), "block_ids tensor must be cuda"); - TORCH_CHECK(seqlens_kvcache.device().is_cuda(), "seqlens_kvcache tensor must be cuda"); +tvm::ffi::Tensor attention_with_kvcache_prefill_fp8_entry( + const tvm::ffi::TensorView &q, const tvm::ffi::TensorView &kcache, + const tvm::ffi::TensorView &vcache, const tvm::ffi::TensorView &qkscale, + const tvm::ffi::TensorView &vscale, const tvm::ffi::TensorView &cu_seqlens_q, + const tvm::ffi::TensorView &block_ids, const tvm::ffi::TensorView &seqlens_kvcache, + int64_t max_seqlens_q, tvm::ffi::Optional output) { + auto stream = TVM_FFI_GET_CUDA_STREAM(q); + TVM_FFI_CHECK_CUDA(q); + TVM_FFI_CHECK_CUDA(kcache); + TVM_FFI_CHECK_CUDA(vcache); + TVM_FFI_CHECK_CUDA(qkscale); + TVM_FFI_CHECK_CUDA(vscale); + TVM_FFI_CHECK_CUDA(cu_seqlens_q); + TVM_FFI_CHECK_CUDA(block_ids); + TVM_FFI_CHECK_CUDA(seqlens_kvcache); - int total_seq_q = q.size(0); - int num_head_q = q.size(1); - int num_dim_qk = q.size(2); + int total_seq_q = q.shape().at(0); + int num_head_q = q.shape().at(1); + int num_dim_qk = q.shape().at(2); - int num_batch = cu_seqlens_q.size(0) - 1; + int num_batch = cu_seqlens_q.shape().at(0) - 1; - int num_kvcache_blocks = kcache.size(0); - int block_size = kcache.size(1); + int num_kvcache_blocks = kcache.shape().at(0); + int block_size = kcache.shape().at(1); - int num_head_kv = kcache.size(2); - int num_dim_v = vcache.size(3); + int num_head_kv = kcache.shape().at(2); + int num_dim_v = vcache.shape().at(3); - int num_seq_max_blocks = block_ids.size(1); + int num_seq_max_blocks = 
block_ids.shape().at(1); - int max_seqlens_q_pad = qkscale.size(2); + int max_seqlens_q_pad = qkscale.shape().at(2); - auto options = q.options().dtype(torch::kBFloat16); - torch::Tensor y; + auto device = q.device(); + tvm::ffi::Tensor y; if (output.has_value()) { - y = output.value(); + y = tvm::ffi::Tensor(output.value()); } else { - y = torch::empty({total_seq_q, num_head_q, num_dim_v}, options); + y = tvm_ffi_empty({total_seq_q, num_head_q, num_dim_v}, dl_bfloat16, device); } int num_tmas = 2 * num_batch; - torch::Tensor tmas = torch::empty({num_tmas, 64}, options); - - const auto *q_ptr = q.const_data_ptr(); - const auto *kcache_ptr = kcache.const_data_ptr(); - const auto *vcache_ptr = vcache.const_data_ptr(); - const auto *qkscale_ptr = qkscale.const_data_ptr(); - const auto *vscale_ptr = vscale.const_data_ptr(); - const auto *cu_seqlens_q_ptr = cu_seqlens_q.const_data_ptr(); - const auto *block_ids_ptr = block_ids.const_data_ptr(); - const auto *seqlens_kvcache_ptr = seqlens_kvcache.const_data_ptr(); - void *tmas_ptr = tmas.mutable_data_ptr(); + tvm::ffi::Tensor tmas = tvm_ffi_empty({num_tmas, 64}, dl_bfloat16, device); + + const auto *q_ptr = q.data_ptr(); + const auto *kcache_ptr = kcache.data_ptr(); + const auto *vcache_ptr = vcache.data_ptr(); + const auto *qkscale_ptr = qkscale.data_ptr(); + const auto *vscale_ptr = vscale.data_ptr(); + const auto *cu_seqlens_q_ptr = cu_seqlens_q.data_ptr(); + const auto *block_ids_ptr = block_ids.data_ptr(); + const auto *seqlens_kvcache_ptr = seqlens_kvcache.data_ptr(); + void *tmas_ptr = tmas.data_ptr(); using T = __nv_bfloat16; - auto *y_ptr = reinterpret_cast(y.mutable_data_ptr()); + auto *y_ptr = reinterpret_cast(y.data_ptr()); int ldQ = q.stride(0); // num_head_q * num_dim_qk; int ldK = kcache.stride(0); // num_head_kv * num_dim_qk; @@ -195,59 +203,58 @@ torch::Tensor attention_with_kvcache_prefill_fp8_entry( return y; } -torch::Tensor attention_decode_bf16_entry(const torch::Tensor &q, torch::Tensor &kcache, - torch::Tensor &vcache, const torch::Tensor &block_ids, - const torch::Tensor &num_seq_kvcache, - bool new_kv_included, bool use_splitk, - std::optional output) { - auto stream = at::cuda::getCurrentCUDAStream(q.get_device()); - - TORCH_CHECK(q.device().is_cuda(), "q tensor must be cuda"); - TORCH_CHECK(kcache.device().is_cuda(), "v tensor must be cuda"); - TORCH_CHECK(vcache.device().is_cuda(), "v tensor must be cuda"); - TORCH_CHECK(block_ids.device().is_cuda(), "v tensor must be cuda"); - TORCH_CHECK(block_ids.is_contiguous(), "block_ids tensor must be contiguous"); - TORCH_CHECK(num_seq_kvcache.is_contiguous(), "num_seq_kvcache tensor must be contiguous"); - TORCH_CHECK(block_ids.scalar_type() == torch::kInt32, "block_ids dtype must be int32"); - TORCH_CHECK(num_seq_kvcache.scalar_type() == torch::kInt32, - "num_seq_kvcache dtype must be int32"); - - int num_batch = num_seq_kvcache.size(0); - int num_seq_q = q.size(0) / num_batch; - TORCH_CHECK(num_seq_q == 1, "num_seq_q must be 1"); - int num_head_q = q.size(1); - int num_dim_qk = q.size(2); - - int num_kvcache_blocks = kcache.size(0); - int block_size = kcache.size(1); - - int num_head_k = kcache.size(2); - int num_head_v = vcache.size(2); - int num_dim_v = vcache.size(3); - - int num_seq_max_blocks = block_ids.size(1); - - const auto *q_ptr = q.const_data_ptr(); - auto *kcache_ptr = kcache.mutable_data_ptr(); - auto *vcache_ptr = vcache.mutable_data_ptr(); - const int *block_ids_ptr = block_ids.const_data_ptr(); - const int *num_seq_kvcache_ptr = 
num_seq_kvcache.const_data_ptr(); - - auto options = q.options().dtype(torch::kBFloat16); - torch::Tensor y; +tvm::ffi::Tensor attention_decode_bf16_entry(const tvm::ffi::TensorView &q, + const tvm::ffi::TensorView &kcache, + const tvm::ffi::TensorView &vcache, + const tvm::ffi::TensorView &block_ids, + const tvm::ffi::TensorView &num_seq_kvcache, + bool new_kv_included, bool use_splitk, + tvm::ffi::Optional output) { + auto stream = TVM_FFI_GET_CUDA_STREAM(q); + + TVM_FFI_CHECK_CUDA(q); + TVM_FFI_CHECK_CUDA(kcache); + TVM_FFI_CHECK_CUDA(vcache); + TVM_FFI_CHECK_CUDA(block_ids); + TVM_FFI_CHECK_CONTIGUOUS(block_ids); + TVM_FFI_CHECK_CONTIGUOUS(num_seq_kvcache); + + int num_batch = num_seq_kvcache.shape().at(0); + int num_seq_q = q.shape().at(0) / num_batch; + TVM_FFI_ICHECK(num_seq_q == 1) << "num_seq_q must be 1"; + int num_head_q = q.shape().at(1); + int num_dim_qk = q.shape().at(2); + + int num_kvcache_blocks = kcache.shape().at(0); + int block_size = kcache.shape().at(1); + + int num_head_k = kcache.shape().at(2); + int num_head_v = vcache.shape().at(2); + int num_dim_v = vcache.shape().at(3); + + int num_seq_max_blocks = block_ids.shape().at(1); + + const auto *q_ptr = q.data_ptr(); + auto *kcache_ptr = const_cast(kcache.data_ptr()); + auto *vcache_ptr = const_cast(vcache.data_ptr()); + const int *block_ids_ptr = reinterpret_cast(block_ids.data_ptr()); + const int *num_seq_kvcache_ptr = reinterpret_cast(num_seq_kvcache.data_ptr()); + + auto device = q.device(); + tvm::ffi::Tensor y; if (output.has_value()) { - y = output.value(); + y = tvm::ffi::Tensor(output.value()); } else { - y = torch::empty({num_batch * num_seq_q, num_head_q, num_dim_v}, options); + y = tvm_ffi_empty({num_batch * num_seq_q, num_head_q, num_dim_v}, dl_bfloat16, device); } - torch::Tensor lse; - torch::Tensor split_out; + tvm::ffi::Tensor lse; + tvm::ffi::Tensor split_out; - int splitk = 0; // small batch increase splitk number to maximize sm usage. // 1. batch <= 32. split one request seqlenk to 16 parts. // 2. batch > 32. split one request seqlenk to 4 parts. + int splitk = 0; if (use_splitk) { if (num_batch <= 32) { splitk = 16; @@ -256,16 +263,18 @@ torch::Tensor attention_decode_bf16_entry(const torch::Tensor &q, torch::Tensor } } + void *lse_ptr = nullptr; + void *split_out_ptr = nullptr; + if (splitk > 0) { - lse = torch::empty({num_batch, splitk, num_head_q}, q.options().dtype(torch::kFloat32)); - split_out = torch::empty({num_batch, splitk, num_head_q, num_dim_v}, - q.options().dtype(torch::kFloat32)); + lse = tvm_ffi_empty({num_batch, splitk, num_head_q}, dl_float32, device); + split_out = + tvm_ffi_empty({num_batch, splitk, num_head_q, num_dim_v}, dl_float32, device); + lse_ptr = lse.data_ptr(); + split_out_ptr = split_out.data_ptr(); } - auto *lse_ptr = splitk > 0 ? lse.mutable_data_ptr() : nullptr; - auto *split_out_ptr = splitk > 0 ? 
split_out.mutable_data_ptr() : nullptr; - - auto *y_ptr = y.mutable_data_ptr(); + auto *y_ptr = y.data_ptr(); int ldQ = q.stride(0); // num_head_q * num_dim_qk; int ldK = kcache.stride(0); @@ -275,76 +284,70 @@ torch::Tensor attention_decode_bf16_entry(const torch::Tensor &q, torch::Tensor bool running = attention_decode_bf16_async( y_ptr, lse_ptr, split_out_ptr, q_ptr, kcache_ptr, vcache_ptr, block_ids_ptr, num_seq_kvcache_ptr, new_kv_included, splitk, num_batch, num_head_q, num_head_k, num_head_v, - num_dim_qk, num_dim_v, num_kvcache_blocks, block_size, num_seq_max_blocks, ldY, ldQ, ldK, ldV, - stream); + num_dim_qk, num_dim_v, num_kvcache_blocks, block_size, num_seq_max_blocks, ldY, ldQ, ldK, + ldV, stream); - TORCH_CHECK(running, "attn decode kernel launch failed!"); + TVM_FFI_ICHECK(running) << "attn decode kernel launch failed!"; return y; } -torch::Tensor attention_decode_fp8_entry(const torch::Tensor &q, torch::Tensor &kcache, - torch::Tensor &vcache, const torch::Tensor &block_ids, - const torch::Tensor &num_seq_kvcache, - const torch::Tensor &qscale, const torch::Tensor &kscale, - const torch::Tensor &vscale, bool new_kv_included, - bool use_splitk, std::optional split_flag, - std::optional output) { - auto stream = at::cuda::getCurrentCUDAStream(q.get_device()); - - TORCH_CHECK(q.device().is_cuda(), "q tensor must be cuda"); - TORCH_CHECK(kcache.device().is_cuda(), "v tensor must be cuda"); - TORCH_CHECK(vcache.device().is_cuda(), "v tensor must be cuda"); - TORCH_CHECK(block_ids.device().is_cuda(), "v tensor must be cuda"); - TORCH_CHECK(block_ids.is_contiguous(), "block_ids tensor must be contiguous"); - TORCH_CHECK(num_seq_kvcache.is_contiguous(), "num_seq_kvcache tensor must be contiguous"); - TORCH_CHECK(q.scalar_type() == torch::kFloat8_e4m3fn, "q dtype must be fp8_e4m3fn"); - TORCH_CHECK(kcache.dtype().itemsize() == 1, "kcache tensor element type size must be fp8_e4m3"); - TORCH_CHECK(vcache.dtype().itemsize() == 1, "vcache tensor element type size must be fp8_e4m3"); - TORCH_CHECK(block_ids.scalar_type() == torch::kInt32, "block_ids dtype must be int32"); - TORCH_CHECK(num_seq_kvcache.scalar_type() == torch::kInt32, - "num_seq_kvcache dtype must be int32"); - - int num_batch = num_seq_kvcache.size(0); - int num_seq_q = q.size(0) / num_batch; - TORCH_CHECK(num_seq_q == 1, "num_seq_q must be 1"); - int num_head_q = q.size(1); - int num_dim_qk = q.size(2); - - int num_kvcache_blocks = kcache.size(0); - int block_size = kcache.size(1); - - int num_head_k = kcache.size(2); - int num_head_v = vcache.size(2); - int num_dim_v = vcache.size(3); - - int num_seq_max_blocks = block_ids.size(1); +tvm::ffi::Tensor attention_decode_fp8_entry( + const tvm::ffi::TensorView &q, const tvm::ffi::TensorView &kcache, + const tvm::ffi::TensorView &vcache, const tvm::ffi::TensorView &block_ids, + const tvm::ffi::TensorView &num_seq_kvcache, const tvm::ffi::TensorView &qscale, + const tvm::ffi::TensorView &kscale, const tvm::ffi::TensorView &vscale, bool new_kv_included, + bool use_splitk, tvm::ffi::Optional split_flag, + tvm::ffi::Optional output) { + auto stream = TVM_FFI_GET_CUDA_STREAM(q); + + TVM_FFI_CHECK_CUDA(q); + TVM_FFI_CHECK_CUDA(kcache); + TVM_FFI_CHECK_CUDA(vcache); + TVM_FFI_CHECK_CUDA(block_ids); + TVM_FFI_CHECK_CONTIGUOUS(block_ids); + TVM_FFI_CHECK_CONTIGUOUS(num_seq_kvcache); + + int num_batch = num_seq_kvcache.shape().at(0); + int num_seq_q = q.shape().at(0) / num_batch; + TVM_FFI_ICHECK(num_seq_q == 1) << "num_seq_q must be 1"; + int num_head_q = q.shape().at(1); + int 
num_dim_qk = q.shape().at(2); + + int num_kvcache_blocks = kcache.shape().at(0); + int block_size = kcache.shape().at(1); + + int num_head_k = kcache.shape().at(2); + int num_head_v = vcache.shape().at(2); + int num_dim_v = vcache.shape().at(3); + + int num_seq_max_blocks = block_ids.shape().at(1); int qscale_pad_stride = qscale.stride(0); - const auto *q_ptr = q.const_data_ptr(); - auto *kcache_ptr = kcache.mutable_data_ptr(); - auto *vcache_ptr = vcache.mutable_data_ptr(); - const int *block_ids_ptr = block_ids.const_data_ptr(); - const int *num_seq_kvcache_ptr = num_seq_kvcache.const_data_ptr(); - const float *qscale_ptr = qscale.const_data_ptr(); - const float *kscale_ptr = kscale.const_data_ptr(); - const float *vscale_ptr = vscale.const_data_ptr(); - - auto options = q.options().dtype(torch::kBFloat16); - torch::Tensor y; + const auto *q_ptr = q.data_ptr(); + auto *kcache_ptr = const_cast(kcache.data_ptr()); + auto *vcache_ptr = const_cast(vcache.data_ptr()); + const int *block_ids_ptr = reinterpret_cast(block_ids.data_ptr()); + const int *num_seq_kvcache_ptr = reinterpret_cast(num_seq_kvcache.data_ptr()); + const float *qscale_ptr = reinterpret_cast(qscale.data_ptr()); + const float *kscale_ptr = reinterpret_cast(kscale.data_ptr()); + const float *vscale_ptr = reinterpret_cast(vscale.data_ptr()); + + auto device = q.device(); + tvm::ffi::Tensor y; if (output.has_value()) { - y = output.value(); + y = tvm::ffi::Tensor(output.value()); } else { - y = torch::empty({num_batch * num_seq_q, num_head_q, num_dim_v}, options); + y = tvm_ffi_empty({num_batch * num_seq_q, num_head_q, num_dim_v}, dl_bfloat16, device); } - torch::Tensor lse; - torch::Tensor split_out; + tvm::ffi::Tensor lse; + tvm::ffi::Tensor split_out_tensor; + // small batch increase splitk number to maximize sm usage. int splitk = 0; int splitk_min_len = 0; - // small batch increase splitk number to maximize sm usage. if (use_splitk) { if (num_batch <= 32) { splitk = 4; @@ -364,25 +367,27 @@ torch::Tensor attention_decode_fp8_entry(const torch::Tensor &q, torch::Tensor & consumers = 2; } - torch::Tensor split_flag_tensor; + tvm::ffi::Tensor split_flag_tensor; if (split_flag.has_value()) { - split_flag_tensor = split_flag.value(); + split_flag_tensor = tvm::ffi::Tensor(split_flag.value()); } else { - split_flag_tensor = torch::zeros({num_batch, num_head_k}, q.options().dtype(torch::kInt32)); + split_flag_tensor = tvm_ffi_zeros({num_batch, num_head_k}, dl_int32, device); } + void *lse_ptr = nullptr; + void *split_out_ptr = nullptr; + if (splitk > 0) { - lse = torch::empty({num_batch, splitk * consumers, num_head_q}, - q.options().dtype(torch::kFloat32)); - split_out = torch::empty({num_batch, splitk * consumers, num_head_q, num_dim_v}, - q.options().dtype(torch::kFloat32)); + lse = tvm_ffi_empty({num_batch, splitk * consumers, num_head_q}, dl_float32, device); + split_out_tensor = + tvm_ffi_empty({num_batch, splitk * consumers, num_head_q, num_dim_v}, dl_float32, device); + lse_ptr = lse.data_ptr(); + split_out_ptr = split_out_tensor.data_ptr(); } - auto *lse_ptr = splitk > 0 ? lse.mutable_data_ptr() : nullptr; - auto *split_out_ptr = splitk > 0 ? 
split_out.mutable_data_ptr() : nullptr; - auto *split_flag_ptr = split_flag_tensor.mutable_data_ptr(); + auto *split_flag_ptr = reinterpret_cast(split_flag_tensor.data_ptr()); - auto *y_ptr = y.mutable_data_ptr(); + auto *y_ptr = y.data_ptr(); int ldQ = q.stride(0); // num_head_q * num_dim_qk; int ldK = kcache.stride(0); @@ -396,7 +401,7 @@ torch::Tensor attention_decode_fp8_entry(const torch::Tensor &q, torch::Tensor & num_dim_v, num_kvcache_blocks, block_size, num_seq_max_blocks, qscale_pad_stride, ldY, ldQ, ldK, ldV, stream); - TORCH_CHECK(running, "attn decode kernel launch failed!"); + TVM_FFI_ICHECK(running) << "attn decode kernel launch failed!"; return y; } @@ -404,34 +409,13 @@ torch::Tensor attention_decode_fp8_entry(const torch::Tensor &q, torch::Tensor & } // namespace attention } // namespace hpc -TORCH_LIBRARY_FRAGMENT(hpc, m) { - m.def( - "attention_prefill_bf16(Tensor q, Tensor k, Tensor v, Tensor seqlens_q, Tensor cu_seqlens_q, " - "int max_seqlens_q, Tensor? output) -> (Tensor)"); - m.impl("attention_prefill_bf16", torch::kCUDA, &hpc::attention::attention_prefill_bf16_entry); - - m.def( - "attention_with_kvcache_prefill_bf16(Tensor q, Tensor kcache, Tensor vcache," - "Tensor cu_seqlens_q, " - "Tensor block_ids, Tensor num_seq_kvcache, int max_seqlens_q, Tensor? output) -> (Tensor)"); - m.impl("attention_with_kvcache_prefill_bf16", torch::kCUDA, - &hpc::attention::attention_with_kvcache_prefill_bf16_entry); - - m.def( - "attention_with_kvcache_prefill_fp8(Tensor q, Tensor kcache, Tensor vcache," - "Tensor qkscale, Tensor vscale, Tensor cu_seqlens_q," - "Tensor block_ids, Tensor num_seq_kvcache, int max_seqlens_q, Tensor? output) -> (Tensor)"); - m.impl("attention_with_kvcache_prefill_fp8", torch::kCUDA, - &hpc::attention::attention_with_kvcache_prefill_fp8_entry); - - m.def( - "attention_decode_bf16(Tensor q, Tensor! kcache, Tensor! vcache, Tensor block_ids, Tensor " - "num_seq_kvcache, bool new_kv_included, bool use_splitk, Tensor? output) -> (Tensor)"); - m.impl("attention_decode_bf16", torch::kCUDA, &hpc::attention::attention_decode_bf16_entry); - - m.def( - "attention_decode_fp8(Tensor q, Tensor! kcache, Tensor! vcache, Tensor block_ids, Tensor " - "num_seq_kvcache, Tensor qscale, Tensor kscale, Tensor vscale, bool new_kv_included, bool " - "use_splitk, Tensor? split_flag, Tensor? output) -> (Tensor)"); - m.impl("attention_decode_fp8", torch::kCUDA, &hpc::attention::attention_decode_fp8_entry); -} +TVM_FFI_DLL_EXPORT_TYPED_FUNC(attention_prefill_bf16, + hpc::attention::attention_prefill_bf16_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(attention_with_kvcache_prefill_bf16, + hpc::attention::attention_with_kvcache_prefill_bf16_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(attention_with_kvcache_prefill_fp8, + hpc::attention::attention_with_kvcache_prefill_fp8_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(attention_decode_bf16, + hpc::attention::attention_decode_bf16_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(attention_decode_fp8, + hpc::attention::attention_decode_fp8_entry); diff --git a/src/fuse_moe/entry.cc b/src/fuse_moe/entry.cc index 9e97c19..f229725 100644 --- a/src/fuse_moe/entry.cc +++ b/src/fuse_moe/entry.cc @@ -1,63 +1,71 @@ // Copyright (C) 2026 Tencent. 
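With TORCH_LIBRARY_FRAGMENT gone, the attention kernels above are exported by name through TVM_FFI_DLL_EXPORT_TYPED_FUNC and are no longer reachable via torch.ops.hpc.*. A minimal calling sketch, assuming the tvm-ffi Python package exposes load_module() and auto-converts DLPack-compatible torch tensors as described in its docs; the hpc/attention.py wrapper in this PR is the authoritative call path, and every shape below is illustrative only:

import torch
import tvm_ffi  # assumption: installed as a runtime dependency of the wheel

# Load the shared library produced by the _C target above (path is illustrative).
mod = tvm_ffi.load_module("hpc/ops/_C.so")

num_batch, num_head_q, num_head_kv, head_dim = 4, 8, 1, 128
num_blocks, block_size, max_blocks_per_seq = 16, 64, 1

q = torch.randn(num_batch, num_head_q, head_dim, dtype=torch.bfloat16, device="cuda")
kcache = torch.randn(num_blocks, block_size, num_head_kv, head_dim,
                     dtype=torch.bfloat16, device="cuda")
vcache = torch.randn(num_blocks, block_size, num_head_kv, head_dim,
                     dtype=torch.bfloat16, device="cuda")
block_ids = torch.arange(num_batch, dtype=torch.int32,
                         device="cuda").reshape(num_batch, max_blocks_per_seq)
num_seq_kvcache = torch.full((num_batch,), block_size, dtype=torch.int32, device="cuda")

# Tensors cross the FFI boundary via DLPack, so the compiled binary does not
# link against a specific torch version; torch is only the caller here.
y = mod.attention_decode_bf16(q, kcache, vcache, block_ids, num_seq_kvcache,
                              True,   # new_kv_included
                              False,  # use_splitk
                              None)   # output: let the kernel allocate y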
-#include #include -#include -#include +#include +#include +#include + +#include #include +#include "tvm_ffi_utils.h" + #include "src/fuse_moe/fuse_moe.h" namespace hpc { namespace fuse_moe { -std::tuple -count_and_gather_entry(const torch::Tensor &x, const torch::Tensor &topk_ids, - const int64_t num_expert, const int64_t rank_ep, - const int64_t intermediate_size, const int64_t num_seq_per_group_avg) { - auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); - TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); - TORCH_CHECK(topk_ids.device().is_cuda(), "topk_ids tensor must be cuda"); - TORCH_CHECK(x.is_contiguous(), "x tensor a must be contiguous"); - TORCH_CHECK(topk_ids.is_contiguous(), "topk_ids tensor a must be contiguous"); - TORCH_CHECK(x.size(0) == topk_ids.size(0), "x and topk_ids must share the same k"); - - int num_seq = x.size(0); - int hidden_size = x.size(1); - int num_topk = topk_ids.size(1); - - auto options = x.options(); - torch::Tensor gate_up_input = torch::empty({num_seq * num_topk, hidden_size}, options); - torch::Tensor gate_up_output = - torch::empty({num_seq * num_topk, intermediate_size}, options.dtype(torch::kBFloat16)); - torch::Tensor down_input = torch::empty({num_seq * num_topk, intermediate_size / 2}, options); - torch::Tensor down_output = - torch::empty({num_seq * num_topk, hidden_size}, options.dtype(torch::kBFloat16)); - - torch::Tensor topk_pos = torch::empty({num_seq, num_topk}, options.dtype(torch::kInt32)); - torch::Tensor seqlens = torch::zeros({num_expert}, options.dtype(torch::kInt32)); - torch::Tensor cu_seqlens = torch::empty({num_expert + 1}, options.dtype(torch::kInt32)); - torch::Tensor tiles = torch::empty({num_expert}, options.dtype(torch::kInt32)); - torch::Tensor cu_tiles = torch::empty({num_expert + 1}, options.dtype(torch::kInt32)); - torch::Tensor gate_up_tmas = torch::empty({num_expert * 2, 128}, options.dtype(torch::kInt8)); - torch::Tensor dowm_tmas = torch::empty({num_expert * 2, 128}, options.dtype(torch::kInt8)); - - const auto *x_ptr = x.const_data_ptr(); - const auto *topk_ids_ptr = topk_ids.const_data_ptr(); - - auto *gate_up_input_ptr = gate_up_input.mutable_data_ptr(); - auto *gate_up_output_ptr = gate_up_output.mutable_data_ptr(); - auto *down_input_ptr = down_input.mutable_data_ptr(); - auto *down_output_ptr = down_output.mutable_data_ptr(); - auto *topk_pos_ptr = topk_pos.mutable_data_ptr(); - auto *seqlens_ptr = seqlens.mutable_data_ptr(); - auto *cu_seqlens_ptr = cu_seqlens.mutable_data_ptr(); - auto *tiles_ptr = tiles.mutable_data_ptr(); - auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr(); - auto *gate_up_tmas_ptr = gate_up_tmas.mutable_data_ptr(); - auto *dowm_tmas_ptr = dowm_tmas.mutable_data_ptr(); +tvm::ffi::Tuple +count_and_gather_entry(const tvm::ffi::TensorView &x, const tvm::ffi::TensorView &topk_ids, + int64_t num_expert, int64_t rank_ep, int64_t intermediate_size, + int64_t num_seq_per_group_avg) { + auto stream = TVM_FFI_GET_CUDA_STREAM(x); + TVM_FFI_CHECK_CUDA(x); + TVM_FFI_CHECK_CUDA(topk_ids); + TVM_FFI_CHECK_CONTIGUOUS(x); + TVM_FFI_CHECK_CONTIGUOUS(topk_ids); + TVM_FFI_ICHECK(x.shape().at(0) == topk_ids.shape().at(0)) + << "x and topk_ids must share the same k"; + + int num_seq = x.shape().at(0); + int hidden_size = x.shape().at(1); + int num_topk = topk_ids.shape().at(1); + + auto device = x.device(); + tvm::ffi::Tensor gate_up_input = + tvm_ffi_empty({num_seq * num_topk, hidden_size}, x.dtype(), device); + tvm::ffi::Tensor gate_up_output = + tvm_ffi_empty({num_seq * num_topk, 
intermediate_size}, dl_bfloat16, device); + tvm::ffi::Tensor down_input = + tvm_ffi_empty({num_seq * num_topk, intermediate_size / 2}, x.dtype(), device); + tvm::ffi::Tensor down_output = + tvm_ffi_empty({num_seq * num_topk, hidden_size}, dl_bfloat16, device); + + tvm::ffi::Tensor topk_pos = tvm_ffi_empty({num_seq, num_topk}, dl_int32, device); + tvm::ffi::Tensor seqlens = tvm_ffi_zeros({num_expert}, dl_int32, device); + tvm::ffi::Tensor cu_seqlens = tvm_ffi_empty({num_expert + 1}, dl_int32, device); + tvm::ffi::Tensor tiles = tvm_ffi_empty({num_expert}, dl_int32, device); + tvm::ffi::Tensor cu_tiles = tvm_ffi_empty({num_expert + 1}, dl_int32, device); + tvm::ffi::Tensor gate_up_tmas = tvm_ffi_empty({num_expert * 2, 128}, dl_int8, device); + tvm::ffi::Tensor dowm_tmas = tvm_ffi_empty({num_expert * 2, 128}, dl_int8, device); + + const auto *x_ptr = x.data_ptr(); + const auto *topk_ids_ptr = topk_ids.data_ptr(); + + auto *gate_up_input_ptr = gate_up_input.data_ptr(); + auto *gate_up_output_ptr = gate_up_output.data_ptr(); + auto *down_input_ptr = down_input.data_ptr(); + auto *down_output_ptr = down_output.data_ptr(); + auto *topk_pos_ptr = topk_pos.data_ptr(); + auto *seqlens_ptr = seqlens.data_ptr(); + auto *cu_seqlens_ptr = cu_seqlens.data_ptr(); + auto *tiles_ptr = tiles.data_ptr(); + auto *cu_tiles_ptr = cu_tiles.data_ptr(); + auto *gate_up_tmas_ptr = gate_up_tmas.data_ptr(); + auto *dowm_tmas_ptr = dowm_tmas.data_ptr(); count_and_gather_async(gate_up_input_ptr, gate_up_output_ptr, down_input_ptr, down_output_ptr, x_ptr, topk_ids_ptr, topk_pos_ptr, seqlens_ptr, cu_seqlens_ptr, @@ -65,52 +73,49 @@ count_and_gather_entry(const torch::Tensor &x, const torch::Tensor &topk_ids, hidden_size, intermediate_size, num_topk, num_expert, rank_ep, num_seq_per_group_avg, stream); - return std::make_tuple(gate_up_input, gate_up_output, topk_pos, seqlens, cu_seqlens, tiles, - cu_tiles, gate_up_tmas, dowm_tmas); + return tvm::ffi::Tuple(gate_up_input, gate_up_output, topk_pos, seqlens, + cu_seqlens, tiles, cu_tiles, gate_up_tmas, dowm_tmas); } -torch::Tensor reduce_entry(const torch::Tensor &x, const torch::Tensor &topk_pos, - const torch::Tensor &topk_scale, - const std::optional &shared_output) { - auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); - TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); - TORCH_CHECK(topk_pos.device().is_cuda(), "topk_pos tensor must be cuda"); - TORCH_CHECK(topk_scale.device().is_cuda(), "topk_scale tensor must be cuda"); - TORCH_CHECK(x.is_contiguous(), "x tensor must be contiguous"); - TORCH_CHECK(topk_pos.is_contiguous(), "topk_pos tensor must be contiguous"); - TORCH_CHECK(topk_scale.is_contiguous(), "topk_scale tensor must be contiguous"); - TORCH_CHECK(topk_pos.size(0) == topk_scale.size(0), - "topk_pos and topk_scale must share the same num_seq"); - TORCH_CHECK(topk_pos.size(1) == topk_scale.size(1), - "topk_pos and topk_scale must share the same num_topk"); +tvm::ffi::Tensor reduce_entry(const tvm::ffi::TensorView &x, const tvm::ffi::TensorView &topk_pos, + const tvm::ffi::TensorView &topk_scale, + tvm::ffi::Optional shared_output) { + auto stream = TVM_FFI_GET_CUDA_STREAM(x); + TVM_FFI_CHECK_CUDA(x); + TVM_FFI_CHECK_CUDA(topk_pos); + TVM_FFI_CHECK_CUDA(topk_scale); + TVM_FFI_CHECK_CONTIGUOUS(x); + TVM_FFI_CHECK_CONTIGUOUS(topk_pos); + TVM_FFI_CHECK_CONTIGUOUS(topk_scale); + TVM_FFI_ICHECK(topk_pos.shape().at(0) == topk_scale.shape().at(0)) + << "topk_pos and topk_scale must share the same num_seq"; + TVM_FFI_ICHECK(topk_pos.shape().at(1) == 
topk_scale.shape().at(1)) + << "topk_pos and topk_scale must share the same num_topk"; const void *shared_output_ptr = nullptr; if (shared_output.has_value()) { - const auto shared_output_tensor = shared_output.value(); - TORCH_CHECK(shared_output_tensor.device().is_cuda(), "shared_output tensor must be cuda"); - TORCH_CHECK(shared_output_tensor.is_contiguous(), "shared_output tensor must be contiguous"); - TORCH_CHECK(shared_output_tensor.dtype() == torch::kBFloat16, - "shared_output tensor dtype must be bfloat16"); - TORCH_CHECK( - shared_output_tensor.size(0) == x.size(0) && shared_output_tensor.size(1) == x.size(1), - "shared_output tensor shape must be same as x tensor"); - shared_output_ptr = shared_output_tensor.const_data_ptr(); + const auto &shared_output_tensor = shared_output.value(); + TVM_FFI_CHECK_CUDA(shared_output_tensor); + TVM_FFI_CHECK_CONTIGUOUS(shared_output_tensor); + shared_output_ptr = shared_output_tensor.data_ptr(); } - int total_num_seq = x.size(0); - int hidden_size = x.size(1); - int num_seq = topk_pos.size(0); - int num_topk = topk_pos.size(1); - TORCH_CHECK(num_topk <= 128, "num_topk must less than or equal to 128"); + int total_num_seq = x.shape().at(0); + int hidden_size = x.shape().at(1); + int num_seq = topk_pos.shape().at(0); + int num_topk = topk_pos.shape().at(1); + TVM_FFI_ICHECK(num_topk <= 128) << "num_topk must less than or equal to 128"; - auto options = x.options(); - torch::Tensor y = torch::empty({num_seq, hidden_size}, options.dtype(torch::kBFloat16)); + auto device = x.device(); + tvm::ffi::Tensor y = tvm_ffi_empty({num_seq, hidden_size}, dl_bfloat16, device); - const auto *x_ptr = x.const_data_ptr(); - const auto *topk_pos_ptr = topk_pos.const_data_ptr(); - const auto *topk_scale_ptr = topk_scale.const_data_ptr(); + const auto *x_ptr = x.data_ptr(); + const auto *topk_pos_ptr = topk_pos.data_ptr(); + const auto *topk_scale_ptr = topk_scale.data_ptr(); - auto *y_ptr = y.mutable_data_ptr(); + auto *y_ptr = y.data_ptr(); reduce_async(y_ptr, x_ptr, topk_pos_ptr, topk_scale_ptr, shared_output_ptr, total_num_seq, num_seq, hidden_size, num_topk, stream); @@ -118,98 +123,98 @@ torch::Tensor reduce_entry(const torch::Tensor &x, const torch::Tensor &topk_pos return y; } -torch::Tensor fuse_moe_pertensor_fp8_entry( - const torch::Tensor &x, const torch::Tensor &gate_up_weight, const torch::Tensor &down_weight, - const torch::Tensor &gate_up_scale, const torch::Tensor &down_scale, - const torch::Tensor &act_and_mul_scale, const torch::Tensor &topk_ids, - const torch::Tensor &topk_scale, const std::optional &shared_output, - int64_t rank_ep, int64_t num_expert_total, bool use_bf16_mul) { - auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); - - TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); - TORCH_CHECK(gate_up_weight.device().is_cuda(), "gate_up_weight tensor must be cuda"); - TORCH_CHECK(gate_up_scale.device().is_cuda(), "gate_up_scale tensor must be cuda"); - TORCH_CHECK(down_scale.device().is_cuda(), "down_scale tensor must be cuda"); - TORCH_CHECK(act_and_mul_scale.device().is_cuda(), "act_and_mul_scale tensor must be cuda"); - TORCH_CHECK(topk_ids.device().is_cuda(), "topk_ids tensor must be cuda"); - TORCH_CHECK(topk_scale.device().is_cuda(), "topk_scale tensor must be cuda"); - - TORCH_CHECK(x.is_contiguous(), "x tensor must be contiguous"); - TORCH_CHECK(gate_up_weight.is_contiguous(), "gate_up_weight tensor must be contiguous"); - TORCH_CHECK(gate_up_scale.is_contiguous(), "gate_up_scale tensor must be contiguous"); - 
TORCH_CHECK(down_weight.is_contiguous(), "down_weight tensor must be contiguous"); - TORCH_CHECK(down_scale.is_contiguous(), "down_scale tensor must be contiguous"); - TORCH_CHECK(topk_ids.is_contiguous(), "topk_ids tensor must be contiguous"); - TORCH_CHECK(topk_scale.is_contiguous(), "topk_scale tensor must be contiguous"); - - TORCH_CHECK(x.size(0) == topk_ids.size(0), "x and topk_ids must share the same num_seq"); - TORCH_CHECK(topk_ids.size(0) == topk_scale.size(0), - "topk_ids and topk_scale must share the same num_seq"); - TORCH_CHECK(topk_ids.size(1) == topk_scale.size(1), - "topk_ids and topk_scale must share the same num_topk"); - TORCH_CHECK(x.size(1) == gate_up_weight.size(2), "x and weight must share the same k"); - TORCH_CHECK(gate_up_weight.size(0) == down_weight.size(0), - "gate_up_weight and down_weight must share the same num_expert"); +tvm::ffi::Tensor fuse_moe_pertensor_fp8_entry( + const tvm::ffi::TensorView &x, const tvm::ffi::TensorView &gate_up_weight, + const tvm::ffi::TensorView &down_weight, const tvm::ffi::TensorView &gate_up_scale, + const tvm::ffi::TensorView &down_scale, const tvm::ffi::TensorView &act_and_mul_scale, + const tvm::ffi::TensorView &topk_ids, const tvm::ffi::TensorView &topk_scale, + tvm::ffi::Optional shared_output, int64_t rank_ep, + int64_t num_expert_total, bool use_bf16_mul) { + auto stream = TVM_FFI_GET_CUDA_STREAM(x); + + TVM_FFI_CHECK_CUDA(x); + TVM_FFI_CHECK_CUDA(gate_up_weight); + TVM_FFI_CHECK_CUDA(gate_up_scale); + TVM_FFI_CHECK_CUDA(down_scale); + TVM_FFI_CHECK_CUDA(act_and_mul_scale); + TVM_FFI_CHECK_CUDA(topk_ids); + TVM_FFI_CHECK_CUDA(topk_scale); + + TVM_FFI_CHECK_CONTIGUOUS(x); + TVM_FFI_CHECK_CONTIGUOUS(gate_up_weight); + TVM_FFI_CHECK_CONTIGUOUS(gate_up_scale); + TVM_FFI_CHECK_CONTIGUOUS(down_weight); + TVM_FFI_CHECK_CONTIGUOUS(down_scale); + TVM_FFI_CHECK_CONTIGUOUS(topk_ids); + TVM_FFI_CHECK_CONTIGUOUS(topk_scale); + + TVM_FFI_ICHECK(x.shape().at(0) == topk_ids.shape().at(0)) + << "x and topk_ids must share the same num_seq"; + TVM_FFI_ICHECK(topk_ids.shape().at(0) == topk_scale.shape().at(0)) + << "topk_ids and topk_scale must share the same num_seq"; + TVM_FFI_ICHECK(topk_ids.shape().at(1) == topk_scale.shape().at(1)) + << "topk_ids and topk_scale must share the same num_topk"; + TVM_FFI_ICHECK(x.shape().at(1) == gate_up_weight.shape().at(2)) + << "x and weight must share the same k"; + TVM_FFI_ICHECK(gate_up_weight.shape().at(0) == down_weight.shape().at(0)) + << "gate_up_weight and down_weight must share the same num_expert"; const void *shared_output_ptr = nullptr; if (shared_output.has_value()) { - const auto shared_output_tensor = shared_output.value(); - TORCH_CHECK(shared_output_tensor.device().is_cuda(), "shared_output tensor must be cuda"); - TORCH_CHECK(shared_output_tensor.is_contiguous(), "shared_output tensor must be contiguous"); - TORCH_CHECK(shared_output_tensor.dtype() == torch::kBFloat16, - "shared_output tensor dtype must be bfloat16"); - TORCH_CHECK( - shared_output_tensor.size(0) == x.size(0) && shared_output_tensor.size(1) == x.size(1), - "shared_output tensor shape must be same as x tensor"); - shared_output_ptr = shared_output_tensor.const_data_ptr(); + const auto &shared_output_tensor = shared_output.value(); + TVM_FFI_CHECK_CUDA(shared_output_tensor); + TVM_FFI_CHECK_CONTIGUOUS(shared_output_tensor); + shared_output_ptr = shared_output_tensor.data_ptr(); } - int num_seq = x.size(0); - int hidden_size = x.size(1); - int num_expert = gate_up_weight.size(0); - int intermediate_size = 
gate_up_weight.size(1); - int num_topk = topk_ids.size(1); - TORCH_CHECK(num_topk <= 128, "num_topk must less than or equal to 128"); - - auto options = x.options(); - torch::Tensor y = torch::empty({num_seq, hidden_size}, options.dtype(torch::kBFloat16)); - - torch::Tensor gate_up_input = torch::empty({num_seq * num_topk, hidden_size}, options); - torch::Tensor gate_up_output = - torch::empty({num_seq * num_topk, intermediate_size}, options.dtype(torch::kBFloat16)); - torch::Tensor gate_up_tmas = torch::empty({num_expert * 2, 128}, options.dtype(torch::kInt8)); - torch::Tensor down_input = torch::empty({num_seq * num_topk, intermediate_size / 2}, options); - torch::Tensor down_output = - torch::empty({num_seq * num_topk, hidden_size}, options.dtype(torch::kBFloat16)); - torch::Tensor down_tmas = torch::empty({num_expert * 2, 128}, options.dtype(torch::kInt8)); - - torch::Tensor topk_pos = torch::empty({num_seq, num_topk}, options.dtype(torch::kInt32)); - torch::Tensor seqlens = torch::zeros({num_expert}, options.dtype(torch::kInt32)); - torch::Tensor cu_seqlens = torch::empty({num_expert + 1}, options.dtype(torch::kInt32)); - torch::Tensor tiles = torch::empty({num_expert}, options.dtype(torch::kInt32)); - torch::Tensor cu_tiles = torch::empty({num_expert + 1}, options.dtype(torch::kInt32)); - - const auto *x_ptr = x.const_data_ptr(); - const auto *topk_ids_ptr = topk_ids.const_data_ptr(); - const auto *topk_scale_ptr = topk_scale.const_data_ptr(); - const auto *gate_up_weight_ptr = gate_up_weight.const_data_ptr(); - const auto *gate_up_scale_ptr = gate_up_scale.const_data_ptr(); - const auto *act_and_mul_scale_ptr = act_and_mul_scale.const_data_ptr(); - const auto *down_weight_ptr = down_weight.const_data_ptr(); - const auto *down_scale_ptr = down_scale.const_data_ptr(); - - auto *y_ptr = y.mutable_data_ptr(); - auto *topk_pos_ptr = topk_pos.mutable_data_ptr(); - auto *seqlens_ptr = seqlens.mutable_data_ptr(); - auto *cu_seqlens_ptr = cu_seqlens.mutable_data_ptr(); - auto *tiles_ptr = tiles.mutable_data_ptr(); - auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr(); - auto *gate_up_input_ptr = gate_up_input.mutable_data_ptr(); - auto *gate_up_output_ptr = gate_up_output.mutable_data_ptr(); - auto *gate_up_tmas_ptr = gate_up_tmas.mutable_data_ptr(); - auto *down_input_ptr = down_input.mutable_data_ptr(); - auto *down_output_ptr = down_output.mutable_data_ptr(); - auto *down_tmas_ptr = down_tmas.mutable_data_ptr(); + int num_seq = x.shape().at(0); + int hidden_size = x.shape().at(1); + int num_expert = gate_up_weight.shape().at(0); + int intermediate_size = gate_up_weight.shape().at(1); + int num_topk = topk_ids.shape().at(1); + TVM_FFI_ICHECK(num_topk <= 128) << "num_topk must less than or equal to 128"; + + auto device = x.device(); + tvm::ffi::Tensor y = tvm_ffi_empty({num_seq, hidden_size}, dl_bfloat16, device); + + tvm::ffi::Tensor gate_up_input = + tvm_ffi_empty({num_seq * num_topk, hidden_size}, x.dtype(), device); + tvm::ffi::Tensor gate_up_output = + tvm_ffi_empty({num_seq * num_topk, intermediate_size}, dl_bfloat16, device); + tvm::ffi::Tensor gate_up_tmas = tvm_ffi_empty({num_expert * 2, 128}, dl_int8, device); + tvm::ffi::Tensor down_input = + tvm_ffi_empty({num_seq * num_topk, intermediate_size / 2}, x.dtype(), device); + tvm::ffi::Tensor down_output = + tvm_ffi_empty({num_seq * num_topk, hidden_size}, dl_bfloat16, device); + tvm::ffi::Tensor down_tmas = tvm_ffi_empty({num_expert * 2, 128}, dl_int8, device); + + tvm::ffi::Tensor topk_pos_tensor = tvm_ffi_empty({num_seq, num_topk}, 
dl_int32, device); + tvm::ffi::Tensor seqlens = tvm_ffi_zeros({num_expert}, dl_int32, device); + tvm::ffi::Tensor cu_seqlens = tvm_ffi_empty({num_expert + 1}, dl_int32, device); + tvm::ffi::Tensor tiles = tvm_ffi_empty({num_expert}, dl_int32, device); + tvm::ffi::Tensor cu_tiles = tvm_ffi_empty({num_expert + 1}, dl_int32, device); + + const auto *x_ptr = x.data_ptr(); + const auto *topk_ids_ptr = topk_ids.data_ptr(); + const auto *topk_scale_ptr = topk_scale.data_ptr(); + const auto *gate_up_weight_ptr = gate_up_weight.data_ptr(); + const auto *gate_up_scale_ptr = gate_up_scale.data_ptr(); + const auto *act_and_mul_scale_ptr = act_and_mul_scale.data_ptr(); + const auto *down_weight_ptr = down_weight.data_ptr(); + const auto *down_scale_ptr = down_scale.data_ptr(); + + auto *y_ptr = y.data_ptr(); + auto *topk_pos_ptr = topk_pos_tensor.data_ptr(); + auto *seqlens_ptr = seqlens.data_ptr(); + auto *cu_seqlens_ptr = cu_seqlens.data_ptr(); + auto *tiles_ptr = tiles.data_ptr(); + auto *cu_tiles_ptr = cu_tiles.data_ptr(); + auto *gate_up_input_ptr = gate_up_input.data_ptr(); + auto *gate_up_output_ptr = gate_up_output.data_ptr(); + auto *gate_up_tmas_ptr = gate_up_tmas.data_ptr(); + auto *down_input_ptr = down_input.data_ptr(); + auto *down_output_ptr = down_output.data_ptr(); + auto *down_tmas_ptr = down_tmas.data_ptr(); fuse_moe_pertensor_fp8_async( y_ptr, x_ptr, gate_up_input_ptr, gate_up_output_ptr, gate_up_weight_ptr, gate_up_scale_ptr, @@ -221,76 +226,66 @@ torch::Tensor fuse_moe_pertensor_fp8_entry( return y; } -torch::Tensor fuse_moe_blockwise_fp8_entry( - const torch::Tensor &x, const torch::Tensor &x_scale, const torch::Tensor &gate_up_weight, - const torch::Tensor &gate_up_weight_scale, const torch::Tensor &down_weight, - const torch::Tensor &down_weight_scale, const torch::Tensor &topk_ids, - const torch::Tensor &topk_scale, const std::optional &shared_output, - int64_t rank_ep, int64_t num_expert_total) { - auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); - - TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); - TORCH_CHECK(x_scale.device().is_cuda(), "x_scale tensor must be cuda"); - TORCH_CHECK(gate_up_weight.device().is_cuda(), "gate_up_weight tensor must be cuda"); - TORCH_CHECK(gate_up_weight_scale.device().is_cuda(), "gate_up_weight_scale tensor must be cuda"); - TORCH_CHECK(down_weight.device().is_cuda(), "down_weight tensor must be cuda"); - TORCH_CHECK(down_weight_scale.device().is_cuda(), "down_weight_scale tensor must be cuda"); - TORCH_CHECK(topk_ids.device().is_cuda(), "topk_ids tensor must be cuda"); - TORCH_CHECK(topk_scale.device().is_cuda(), "topk_scale tensor must be cuda"); - - TORCH_CHECK(x.is_contiguous(), "x tensor must be contiguous"); - TORCH_CHECK(x_scale.is_contiguous(), "x_scale tensor must be contiguous"); - TORCH_CHECK(gate_up_weight.is_contiguous(), "gate_up_weight tensor must be contiguous"); - TORCH_CHECK(gate_up_weight_scale.is_contiguous(), - "gate_up_weight_scale tensor must be contiguous"); - TORCH_CHECK(down_weight.is_contiguous(), "down_weight tensor must be contiguous"); - TORCH_CHECK(down_weight_scale.is_contiguous(), "down_weight_scale tensor must be contiguous"); - TORCH_CHECK(topk_ids.is_contiguous(), "topk_ids tensor must be contiguous"); - TORCH_CHECK(topk_scale.is_contiguous(), "topk_scale tensor must be contiguous"); - - TORCH_CHECK(x.size(0) == topk_ids.size(0), "x and topk_ids must share the same num_tokens"); - TORCH_CHECK(topk_ids.size(0) == topk_scale.size(0), - "topk_ids and topk_scale must share the same 
num_tokens"); - TORCH_CHECK(topk_ids.size(1) == topk_scale.size(1), - "topk_ids and topk_scale must share the same num_topk"); - TORCH_CHECK(x.size(1) == gate_up_weight.size(2), "x and weight must share the same k"); - TORCH_CHECK(gate_up_weight.size(0) == down_weight.size(0), - "gate_up_weight and down_weight must share the same num_expert"); - TORCH_CHECK(x_scale.size(0) == x.size(0), "x_scale and x must share the same nun_tokens"); - TORCH_CHECK(x_scale.size(1) == x.size(1) / 128, "x_scale must be per 128 blockwise quant"); - TORCH_CHECK(gate_up_weight_scale.size(1) == gate_up_weight.size(1) / 128, - "gate_up_weight must be per 128 blockwise quant"); - TORCH_CHECK(gate_up_weight_scale.size(2) == (gate_up_weight.size(2) / 128 + 3) / 4 * 4, - "gate_up_weight must be per 128 blockwise quant and must be aligned to 4"); - TORCH_CHECK(down_weight_scale.size(1) == down_weight.size(1) / 128, - "down_weight must be per 128 blockwise quant"); - TORCH_CHECK(down_weight_scale.size(2) == (down_weight.size(2) / 128 + 3) / 4 * 4, - "down_weight must be per 128 blockwise quant and must be aligned to 4"); +tvm::ffi::Tensor fuse_moe_blockwise_fp8_entry( + const tvm::ffi::TensorView &x, const tvm::ffi::TensorView &x_scale, + const tvm::ffi::TensorView &gate_up_weight, + const tvm::ffi::TensorView &gate_up_weight_scale, + const tvm::ffi::TensorView &down_weight, const tvm::ffi::TensorView &down_weight_scale, + const tvm::ffi::TensorView &topk_ids, const tvm::ffi::TensorView &topk_scale, + tvm::ffi::Optional shared_output, int64_t rank_ep, + int64_t num_expert_total) { + auto stream = TVM_FFI_GET_CUDA_STREAM(x); + + TVM_FFI_CHECK_CUDA(x); + TVM_FFI_CHECK_CUDA(x_scale); + TVM_FFI_CHECK_CUDA(gate_up_weight); + TVM_FFI_CHECK_CUDA(gate_up_weight_scale); + TVM_FFI_CHECK_CUDA(down_weight); + TVM_FFI_CHECK_CUDA(down_weight_scale); + TVM_FFI_CHECK_CUDA(topk_ids); + TVM_FFI_CHECK_CUDA(topk_scale); + + TVM_FFI_CHECK_CONTIGUOUS(x); + TVM_FFI_CHECK_CONTIGUOUS(x_scale); + TVM_FFI_CHECK_CONTIGUOUS(gate_up_weight); + TVM_FFI_CHECK_CONTIGUOUS(gate_up_weight_scale); + TVM_FFI_CHECK_CONTIGUOUS(down_weight); + TVM_FFI_CHECK_CONTIGUOUS(down_weight_scale); + TVM_FFI_CHECK_CONTIGUOUS(topk_ids); + TVM_FFI_CHECK_CONTIGUOUS(topk_scale); + + TVM_FFI_ICHECK(x.shape().at(0) == topk_ids.shape().at(0)) + << "x and topk_ids must share the same num_tokens"; + TVM_FFI_ICHECK(topk_ids.shape().at(0) == topk_scale.shape().at(0)) + << "topk_ids and topk_scale must share the same num_tokens"; + TVM_FFI_ICHECK(topk_ids.shape().at(1) == topk_scale.shape().at(1)) + << "topk_ids and topk_scale must share the same num_topk"; + TVM_FFI_ICHECK(x.shape().at(1) == gate_up_weight.shape().at(2)) + << "x and weight must share the same k"; + TVM_FFI_ICHECK(gate_up_weight.shape().at(0) == down_weight.shape().at(0)) + << "gate_up_weight and down_weight must share the same num_expert"; const void *shared_output_ptr = nullptr; if (shared_output.has_value()) { - const auto shared_output_tensor = shared_output.value(); - TORCH_CHECK(shared_output_tensor.device().is_cuda(), "shared_output tensor must be cuda"); - TORCH_CHECK(shared_output_tensor.is_contiguous(), "shared_output tensor must be contiguous"); - TORCH_CHECK(shared_output_tensor.dtype() == torch::kBFloat16, - "shared_output tensor dtype must be bfloat16"); - TORCH_CHECK( - shared_output_tensor.size(0) == x.size(0) && shared_output_tensor.size(1) == x.size(1), - "shared_output tensor shape must be same as x tensor"); - shared_output_ptr = shared_output_tensor.const_data_ptr(); + const auto 
&shared_output_tensor = shared_output.value(); + TVM_FFI_CHECK_CUDA(shared_output_tensor); + TVM_FFI_CHECK_CONTIGUOUS(shared_output_tensor); + shared_output_ptr = shared_output_tensor.data_ptr(); } - int num_tokens = x.size(0); - int hidden_size = x.size(1); - int num_experts = gate_up_weight.size(0); - int intermediate_size = gate_up_weight.size(1); - int num_topk = topk_ids.size(1); + int num_tokens = x.shape().at(0); + int hidden_size = x.shape().at(1); + int num_experts = gate_up_weight.shape().at(0); + int intermediate_size = gate_up_weight.shape().at(1); + int num_topk = topk_ids.shape().at(1); int num_tokens_per_group_avg = num_tokens * num_topk / num_expert_total; int aligned_size = 0; - int gate_up_weight_scale_lastdim_pad4 = gate_up_weight_scale.size(-1); - int down_weight_scale_lastdim_pad4 = down_weight_scale.size(-1); + int gate_up_weight_scale_lastdim_pad4 = gate_up_weight_scale.shape().at( + gate_up_weight_scale.ndim() - 1); + int down_weight_scale_lastdim_pad4 = down_weight_scale.shape().at( + down_weight_scale.ndim() - 1); - TORCH_CHECK(num_topk <= 128, "num_topk must less than or equal to 128"); + TVM_FFI_ICHECK(num_topk <= 128) << "num_topk must less than or equal to 128"; if (num_tokens_per_group_avg <= 16) { aligned_size = 16; @@ -305,52 +300,53 @@ torch::Tensor fuse_moe_blockwise_fp8_entry( (num_tokens * num_topk + num_expert_total * aligned_size + aligned_size - 1) / aligned_size * aligned_size; - auto options = x.options(); - torch::Tensor y = torch::empty({num_tokens, hidden_size}, options.dtype(torch::kBFloat16)); - torch::Tensor gate_up_input = - torch::empty({num_tokens * num_topk, hidden_size}, options.dtype(torch::kFloat8_e4m3fn)); - torch::Tensor gate_up_input_scale = - torch::empty({x_scale.size(1), num_padded_tokens}, options.dtype(torch::kFloat32)); - torch::Tensor gate_up_output = - torch::empty({num_tokens * num_topk, intermediate_size}, options.dtype(torch::kBFloat16)); - torch::Tensor gate_up_tmas = torch::empty({num_experts * 2, 128}, options.dtype(torch::kInt8)); - torch::Tensor down_input = torch::empty({num_tokens * num_topk, intermediate_size / 2}, - options.dtype(torch::kFloat8_e4m3fn)); - torch::Tensor down_input_scale = torch::empty({intermediate_size / 2 / 128, num_padded_tokens}, - options.dtype(torch::kFloat32)); - torch::Tensor down_output = - torch::empty({num_tokens * num_topk, hidden_size}, options.dtype(torch::kBFloat16)); - torch::Tensor down_tmas = torch::empty({num_experts * 2, 128}, options.dtype(torch::kInt8)); - torch::Tensor topk_pos = torch::empty({num_tokens, num_topk}, options.dtype(torch::kInt32)); - torch::Tensor num_tokens_per_group = torch::zeros({num_experts}, options.dtype(torch::kInt32)); - torch::Tensor cu_num_tokens_per_group = - torch::empty({num_experts + 1}, options.dtype(torch::kInt32)); - torch::Tensor tiles = torch::empty({num_experts}, options.dtype(torch::kInt32)); - torch::Tensor cu_tiles = torch::empty({num_experts + 1}, options.dtype(torch::kInt32)); - - const auto *x_ptr = x.const_data_ptr(); - const auto *x_scale_ptr = x_scale.const_data_ptr(); - const auto *topk_ids_ptr = topk_ids.const_data_ptr(); - const auto *topk_scale_ptr = topk_scale.const_data_ptr(); - const auto *gate_up_weight_ptr = gate_up_weight.const_data_ptr(); - const auto *gate_up_weight_scale_ptr = gate_up_weight_scale.const_data_ptr(); - const auto *down_weight_ptr = down_weight.const_data_ptr(); - const auto *down_weight_scale_ptr = down_weight_scale.const_data_ptr(); - - auto *y_ptr = y.mutable_data_ptr(); - auto *topk_pos_ptr = 
topk_pos.mutable_data_ptr(); - auto *num_tokens_per_group_ptr = num_tokens_per_group.mutable_data_ptr(); - auto *cu_num_tokens_per_group_ptr = cu_num_tokens_per_group.mutable_data_ptr(); - auto *tiles_ptr = tiles.mutable_data_ptr(); - auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr(); - auto *gate_up_input_ptr = gate_up_input.mutable_data_ptr(); - auto *gate_up_input_scale_ptr = gate_up_input_scale.mutable_data_ptr(); - auto *gate_up_output_ptr = gate_up_output.mutable_data_ptr(); - auto *gate_up_tmas_ptr = gate_up_tmas.mutable_data_ptr(); - auto *down_input_ptr = down_input.mutable_data_ptr(); - auto *down_input_scale_ptr = down_input_scale.mutable_data_ptr(); - auto *down_output_ptr = down_output.mutable_data_ptr(); - auto *down_tmas_ptr = down_tmas.mutable_data_ptr(); + auto device = x.device(); + tvm::ffi::Tensor y = tvm_ffi_empty({num_tokens, hidden_size}, dl_bfloat16, device); + tvm::ffi::Tensor gate_up_input = + tvm_ffi_empty({num_tokens * num_topk, hidden_size}, dl_float8_e4m3, device); + tvm::ffi::Tensor gate_up_input_scale = + tvm_ffi_empty({x_scale.shape().at(1), num_padded_tokens}, dl_float32, device); + tvm::ffi::Tensor gate_up_output = + tvm_ffi_empty({num_tokens * num_topk, intermediate_size}, dl_bfloat16, device); + tvm::ffi::Tensor gate_up_tmas = tvm_ffi_empty({num_experts * 2, 128}, dl_int8, device); + tvm::ffi::Tensor down_input = + tvm_ffi_empty({num_tokens * num_topk, intermediate_size / 2}, dl_float8_e4m3, device); + tvm::ffi::Tensor down_input_scale = + tvm_ffi_empty({intermediate_size / 2 / 128, num_padded_tokens}, dl_float32, device); + tvm::ffi::Tensor down_output = + tvm_ffi_empty({num_tokens * num_topk, hidden_size}, dl_bfloat16, device); + tvm::ffi::Tensor down_tmas = tvm_ffi_empty({num_experts * 2, 128}, dl_int8, device); + tvm::ffi::Tensor topk_pos_tensor = + tvm_ffi_empty({num_tokens, num_topk}, dl_int32, device); + tvm::ffi::Tensor num_tokens_per_group = tvm_ffi_zeros({num_experts}, dl_int32, device); + tvm::ffi::Tensor cu_num_tokens_per_group = + tvm_ffi_empty({num_experts + 1}, dl_int32, device); + tvm::ffi::Tensor tiles = tvm_ffi_empty({num_experts}, dl_int32, device); + tvm::ffi::Tensor cu_tiles = tvm_ffi_empty({num_experts + 1}, dl_int32, device); + + const auto *x_ptr = x.data_ptr(); + const auto *x_scale_ptr = x_scale.data_ptr(); + const auto *topk_ids_ptr = topk_ids.data_ptr(); + const auto *topk_scale_ptr = topk_scale.data_ptr(); + const auto *gate_up_weight_ptr = gate_up_weight.data_ptr(); + const auto *gate_up_weight_scale_ptr = gate_up_weight_scale.data_ptr(); + const auto *down_weight_ptr = down_weight.data_ptr(); + const auto *down_weight_scale_ptr = down_weight_scale.data_ptr(); + + auto *y_ptr = y.data_ptr(); + auto *topk_pos_ptr = topk_pos_tensor.data_ptr(); + auto *num_tokens_per_group_ptr = num_tokens_per_group.data_ptr(); + auto *cu_num_tokens_per_group_ptr = cu_num_tokens_per_group.data_ptr(); + auto *tiles_ptr = tiles.data_ptr(); + auto *cu_tiles_ptr = cu_tiles.data_ptr(); + auto *gate_up_input_ptr = gate_up_input.data_ptr(); + auto *gate_up_input_scale_ptr = gate_up_input_scale.data_ptr(); + auto *gate_up_output_ptr = gate_up_output.data_ptr(); + auto *gate_up_tmas_ptr = gate_up_tmas.data_ptr(); + auto *down_input_ptr = down_input.data_ptr(); + auto *down_input_scale_ptr = down_input_scale.data_ptr(); + auto *down_output_ptr = down_output.data_ptr(); + auto *down_tmas_ptr = down_tmas.data_ptr(); fuse_moe_blockwise_fp8_async( y_ptr, x_ptr, x_scale_ptr, gate_up_input_ptr, gate_up_input_scale_ptr, gate_up_output_ptr, @@ -366,31 +362,9 @@ 
torch::Tensor fuse_moe_blockwise_fp8_entry( } // namespace fuse_moe } // namespace hpc -TORCH_LIBRARY_FRAGMENT(hpc, m) { - m.def( - "count_and_gather(Tensor x, Tensor topk_ids, int num_expert, int rank_ep, int " - "intermediate_size, int num_seq_per_group_avg" - ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); - m.impl("count_and_gather", torch::kCUDA, &hpc::fuse_moe::count_and_gather_entry); - - m.def( - "reduce(Tensor x, Tensor topk_pos, Tensor topk_scale, Tensor ? shared_output" - ") -> (Tensor)"); - m.impl("reduce", torch::kCUDA, &hpc::fuse_moe::reduce_entry); - - m.def( - "fuse_moe_pertensor_fp8(Tensor x, Tensor gate_up_weight, Tensor down_weight, Tensor " - "gate_up_scale, " - "Tensor down_scale, Tensor act_and_mul_scale, Tensor topk_ids, Tensor topk_scale, Tensor ? " - "shared_output, " - "int rank_ep, int num_expert_total, bool use_bf16_mul) -> (Tensor)"); - m.impl("fuse_moe_pertensor_fp8", torch::kCUDA, &hpc::fuse_moe::fuse_moe_pertensor_fp8_entry); - - m.def( - "fuse_moe_blockwise_fp8(Tensor x, Tensor x_scale, Tensor gate_up_weight, Tensor " - "gate_up_weight_scale, " - "Tensor down_weight, Tensor down_weight_scale, Tensor topk_ids, Tensor topk_scale, Tensor ? " - "shared_output, " - "int rank_ep, int num_expert_total) -> (Tensor)"); - m.impl("fuse_moe_blockwise_fp8", torch::kCUDA, &hpc::fuse_moe::fuse_moe_blockwise_fp8_entry); -} +TVM_FFI_DLL_EXPORT_TYPED_FUNC(count_and_gather, hpc::fuse_moe::count_and_gather_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(reduce, hpc::fuse_moe::reduce_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(fuse_moe_pertensor_fp8, + hpc::fuse_moe::fuse_moe_pertensor_fp8_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(fuse_moe_blockwise_fp8, + hpc::fuse_moe::fuse_moe_blockwise_fp8_entry); diff --git a/src/group_gemm/entry.cc b/src/group_gemm/entry.cc index 5c74dff..6a4bc4a 100644 --- a/src/group_gemm/entry.cc +++ b/src/group_gemm/entry.cc @@ -1,70 +1,72 @@ // Copyright (C) 2026 Tencent. 
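The blockwise-FP8 MoE entry above expects per-128 quantization scales whose trailing dimension is rounded up to a multiple of 4, and it pads the gathered token count to a tile size chosen from num_tokens_per_group_avg (16, 32 or 64). A short worked example of those shape relations with illustrative sizes; the dropped TORCH_CHECKs and the aligned_size / num_padded_tokens arithmetic kept in the new entry are the source of truth:

num_tokens, num_topk, num_expert_total = 64, 8, 32
hidden_size, intermediate_size = 7168, 2048

# activation scales: one fp32 per 128 input channels
x_scale_shape = (num_tokens, hidden_size // 128)                 # (64, 56)

# weight scales: per-128 blocks along k, last dim padded to a multiple of 4
gate_up_scale_last = (hidden_size // 128 + 3) // 4 * 4           # 56 stays 56; 57 would pad to 60
down_scale_last = (intermediate_size // 2 // 128 + 3) // 4 * 4   # 8

# token padding used when gathering activations per expert group
num_tokens_per_group_avg = num_tokens * num_topk // num_expert_total   # 16
aligned_size = 16  # the <= 16 branch shown above; larger averages pick 32 or 64
num_padded_tokens = ((num_tokens * num_topk + num_expert_total * aligned_size
                      + aligned_size - 1) // aligned_size * aligned_size)  # 1024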
-#include #include -#include -#include -#include +#include +#include +#include + +#include "tvm_ffi_utils.h" #include "src/group_gemm/group_gemm.h" namespace hpc { namespace group_gemm { -torch::Tensor group_gemm_pertensor_fp8_entry(const torch::Tensor &x, const torch::Tensor &weight, - const torch::Tensor &seqlens, - const torch::Tensor &cu_seqlens, - const torch::Tensor &y_scale, - const int64_t num_seq_per_group_avg, - std::optional output, - std::optional tma_desc) { - auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); - TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); - TORCH_CHECK(weight.device().is_cuda(), "weight tensor must be cuda"); - TORCH_CHECK(seqlens.device().is_cuda(), "seqlens tensor must be cuda"); - TORCH_CHECK(cu_seqlens.device().is_cuda(), "cu_seqlens tensor must be cuda"); - TORCH_CHECK(x.is_contiguous(), "x tensor a must be contiguous"); - TORCH_CHECK(weight.is_contiguous(), "weight tensor a must be contiguous"); - TORCH_CHECK(seqlens.size(0) == weight.size(0), - "seqlens and weight must share the same num_group"); - TORCH_CHECK(x.size(1) == weight.size(2), "x and weight must share the same k"); - - int m = x.size(0); - int k = x.size(1); - int n = weight.size(1); - int num_group = seqlens.size(0); - - auto options = x.options(); - torch::Tensor y; +tvm::ffi::Tensor group_gemm_pertensor_fp8_entry( + const tvm::ffi::TensorView &x, const tvm::ffi::TensorView &weight, + const tvm::ffi::TensorView &seqlens, const tvm::ffi::TensorView &cu_seqlens, + const tvm::ffi::TensorView &y_scale, int64_t num_seq_per_group_avg, + tvm::ffi::Optional output, + tvm::ffi::Optional tma_desc) { + auto stream = TVM_FFI_GET_CUDA_STREAM(x); + TVM_FFI_CHECK_CUDA(x); + TVM_FFI_CHECK_CUDA(weight); + TVM_FFI_CHECK_CUDA(seqlens); + TVM_FFI_CHECK_CUDA(cu_seqlens); + TVM_FFI_CHECK_CONTIGUOUS(x); + TVM_FFI_CHECK_CONTIGUOUS(weight); + TVM_FFI_ICHECK(seqlens.shape().at(0) == weight.shape().at(0)) + << "seqlens and weight must share the same num_group"; + TVM_FFI_ICHECK(x.shape().at(1) == weight.shape().at(2)) + << "x and weight must share the same k"; + + int m = x.shape().at(0); + int k = x.shape().at(1); + int n = weight.shape().at(1); + int num_group = seqlens.shape().at(0); + + auto device = x.device(); + + tvm::ffi::Tensor y; if (output.has_value()) { - y = output.value(); + y = tvm::ffi::Tensor(output.value()); } else { - y = torch::empty({m, n}, options.dtype(torch::kBFloat16)); + y = tvm_ffi_empty({m, n}, dl_bfloat16, device); } - torch::Tensor tmas; + tvm::ffi::Tensor tmas; bool update_tma = true; if (tma_desc.has_value()) { - tmas = tma_desc.value(); + tmas = tvm::ffi::Tensor(tma_desc.value()); update_tma = false; } else { - tmas = torch::empty({num_group * 2, 128}, options); + tmas = tvm_ffi_empty({num_group * 2, 128}, x.dtype(), device); } - torch::Tensor tiles = torch::empty({num_group}, options.dtype(torch::kInt32)); - torch::Tensor cu_tiles = torch::empty({num_group + 1}, options.dtype(torch::kInt32)); + tvm::ffi::Tensor tiles = tvm_ffi_empty({num_group}, dl_int32, device); + tvm::ffi::Tensor cu_tiles = tvm_ffi_empty({num_group + 1}, dl_int32, device); - const auto *x_ptr = x.const_data_ptr(); - const auto *weight_ptr = weight.const_data_ptr(); - const auto *seqlens_ptr = seqlens.const_data_ptr(); - const auto *cu_seqlens_ptr = cu_seqlens.const_data_ptr(); - const auto *yscale_ptr = y_scale.const_data_ptr(); - auto *tmas_ptr = tmas.mutable_data_ptr(); - auto *y_ptr = y.mutable_data_ptr(); + const auto *x_ptr = x.data_ptr(); + const auto *weight_ptr = 
weight.data_ptr(); + const auto *seqlens_ptr = seqlens.data_ptr(); + const auto *cu_seqlens_ptr = cu_seqlens.data_ptr(); + const auto *yscale_ptr = y_scale.data_ptr(); + auto *tmas_ptr = tmas.data_ptr(); + auto *y_ptr = y.data_ptr(); - auto *tiles_ptr = tiles.mutable_data_ptr(); - auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr(); + auto *tiles_ptr = tiles.data_ptr(); + auto *cu_tiles_ptr = cu_tiles.data_ptr(); group_gemm_pertensor_fp8_async(y_ptr, x_ptr, weight_ptr, seqlens_ptr, cu_seqlens_ptr, yscale_ptr, tmas_ptr, tiles_ptr, cu_tiles_ptr, num_group, m, n, k, @@ -73,61 +75,64 @@ torch::Tensor group_gemm_pertensor_fp8_entry(const torch::Tensor &x, const torch return y; } -torch::Tensor group_gemm_blockwise_fp8_entry( - const torch::Tensor &x, const torch::Tensor &weight, const torch::Tensor &seqlens, - const torch::Tensor &cu_seqlens, const torch::Tensor &x_scale, const torch::Tensor &w_scale, - const int64_t num_seq_per_group_avg, std::optional output, - std::optional tma_desc) { - auto stream = at::cuda::getCurrentCUDAStream(x.get_device()); - TORCH_CHECK(x.device().is_cuda(), "x tensor must be cuda"); - TORCH_CHECK(weight.device().is_cuda(), "weight tensor must be cuda"); - TORCH_CHECK(seqlens.device().is_cuda(), "seqlens tensor must be cuda"); - TORCH_CHECK(cu_seqlens.device().is_cuda(), "cu_seqlens tensor must be cuda"); - TORCH_CHECK(x.is_contiguous(), "x tensor a must be contiguous"); - TORCH_CHECK(weight.is_contiguous(), "weight tensor a must be contiguous"); - TORCH_CHECK(seqlens.size(0) == weight.size(0), - "seqlens and weight must share the same num_group"); - TORCH_CHECK(x.size(1) == weight.size(2), "x and weight must share the same k"); - TORCH_CHECK(w_scale.size(2) % 4 == 0, "w_scale must be multiple of 4"); - - int m = x.size(0); - int k = x.size(1); - int n = weight.size(1); - int m_pad = x_scale.size(1); - int num_block_k_pad4 = w_scale.size(2); - int num_group = seqlens.size(0); - - auto options = x.options(); - torch::Tensor y; +tvm::ffi::Tensor group_gemm_blockwise_fp8_entry( + const tvm::ffi::TensorView &x, const tvm::ffi::TensorView &weight, + const tvm::ffi::TensorView &seqlens, const tvm::ffi::TensorView &cu_seqlens, + const tvm::ffi::TensorView &x_scale, const tvm::ffi::TensorView &w_scale, + int64_t num_seq_per_group_avg, tvm::ffi::Optional output, + tvm::ffi::Optional tma_desc) { + auto stream = TVM_FFI_GET_CUDA_STREAM(x); + TVM_FFI_CHECK_CUDA(x); + TVM_FFI_CHECK_CUDA(weight); + TVM_FFI_CHECK_CUDA(seqlens); + TVM_FFI_CHECK_CUDA(cu_seqlens); + TVM_FFI_CHECK_CONTIGUOUS(x); + TVM_FFI_CHECK_CONTIGUOUS(weight); + TVM_FFI_ICHECK(seqlens.shape().at(0) == weight.shape().at(0)) + << "seqlens and weight must share the same num_group"; + TVM_FFI_ICHECK(x.shape().at(1) == weight.shape().at(2)) + << "x and weight must share the same k"; + TVM_FFI_ICHECK(w_scale.shape().at(2) % 4 == 0) << "w_scale must be multiple of 4"; + + int m = x.shape().at(0); + int k = x.shape().at(1); + int n = weight.shape().at(1); + int m_pad = x_scale.shape().at(1); + int num_block_k_pad4 = w_scale.shape().at(2); + int num_group = seqlens.shape().at(0); + + auto device = x.device(); + + tvm::ffi::Tensor y; if (output.has_value()) { - y = output.value(); + y = tvm::ffi::Tensor(output.value()); } else { - y = torch::empty({m, n}, options.dtype(torch::kBFloat16)); + y = tvm_ffi_empty({m, n}, dl_bfloat16, device); } - torch::Tensor tmas; + tvm::ffi::Tensor tmas; bool update_tma = true; if (tma_desc.has_value()) { - tmas = tma_desc.value(); + tmas = tvm::ffi::Tensor(tma_desc.value()); update_tma = false; 
} else { - tmas = torch::empty({num_group * 2, 128}, options); + tmas = tvm_ffi_empty({num_group * 2, 128}, x.dtype(), device); } - torch::Tensor tiles = torch::empty({num_group}, options.dtype(torch::kInt32)); - torch::Tensor cu_tiles = torch::empty({num_group + 1}, options.dtype(torch::kInt32)); + tvm::ffi::Tensor tiles = tvm_ffi_empty({num_group}, dl_int32, device); + tvm::ffi::Tensor cu_tiles = tvm_ffi_empty({num_group + 1}, dl_int32, device); - const auto *x_ptr = x.const_data_ptr(); - const auto *weight_ptr = weight.const_data_ptr(); - const auto *seqlens_ptr = seqlens.const_data_ptr(); - const auto *cu_seqlens_ptr = cu_seqlens.const_data_ptr(); - const auto *xscale_ptr = x_scale.const_data_ptr(); - const auto *wscale_ptr = w_scale.const_data_ptr(); - auto *tmas_ptr = tmas.mutable_data_ptr(); - auto *y_ptr = y.mutable_data_ptr(); + const auto *x_ptr = x.data_ptr(); + const auto *weight_ptr = weight.data_ptr(); + const auto *seqlens_ptr = seqlens.data_ptr(); + const auto *cu_seqlens_ptr = cu_seqlens.data_ptr(); + const auto *xscale_ptr = x_scale.data_ptr(); + const auto *wscale_ptr = w_scale.data_ptr(); + auto *tmas_ptr = tmas.data_ptr(); + auto *y_ptr = y.data_ptr(); - auto *tiles_ptr = tiles.mutable_data_ptr(); - auto *cu_tiles_ptr = cu_tiles.mutable_data_ptr(); + auto *tiles_ptr = tiles.data_ptr(); + auto *cu_tiles_ptr = cu_tiles.data_ptr(); group_gemm_blockwise_fp8_async(y_ptr, x_ptr, weight_ptr, seqlens_ptr, cu_seqlens_ptr, xscale_ptr, wscale_ptr, tmas_ptr, tiles_ptr, cu_tiles_ptr, num_group, m, n, k, @@ -137,23 +142,25 @@ torch::Tensor group_gemm_blockwise_fp8_entry( return y; } -torch::Tensor reformat_x_scale_entry(const torch::Tensor &x_scale, const torch::Tensor &seqlens, - const torch::Tensor &cu_seqlens, - std::optional out_x_scale, - const int64_t num_seq_per_group_avg) { - auto stream = at::cuda::getCurrentCUDAStream(x_scale.get_device()); - TORCH_CHECK(x_scale.device().is_cuda(), "x_scale tensor must be cuda"); - TORCH_CHECK(seqlens.device().is_cuda(), "seqlens tensor must be cuda"); - TORCH_CHECK(cu_seqlens.device().is_cuda(), "cu_seqlens tensor must be cuda"); - TORCH_CHECK(x_scale.is_contiguous(), "x_scale tensor a must be contiguous"); - TORCH_CHECK(seqlens.is_contiguous(), "seqlens tensor a must be contiguous"); - TORCH_CHECK(cu_seqlens.is_contiguous(), "cu_seqlens tensor a must be contiguous"); - - int m = x_scale.size(0); - int n = x_scale.size(1); - TORCH_CHECK(n == 16 || n == 32, "n must be 16 or 32(for dsv4 group gemm k=2048 or 4096)"); - - int num_group = seqlens.size(0); +tvm::ffi::Tensor reformat_x_scale_entry(const tvm::ffi::TensorView &x_scale, + const tvm::ffi::TensorView &seqlens, + const tvm::ffi::TensorView &cu_seqlens, + tvm::ffi::Optional out_x_scale, + int64_t num_seq_per_group_avg) { + auto stream = TVM_FFI_GET_CUDA_STREAM(x_scale); + TVM_FFI_CHECK_CUDA(x_scale); + TVM_FFI_CHECK_CUDA(seqlens); + TVM_FFI_CHECK_CUDA(cu_seqlens); + TVM_FFI_CHECK_CONTIGUOUS(x_scale); + TVM_FFI_CHECK_CONTIGUOUS(seqlens); + TVM_FFI_CHECK_CONTIGUOUS(cu_seqlens); + + int m = x_scale.shape().at(0); + int n = x_scale.shape().at(1); + TVM_FFI_ICHECK(n == 16 || n == 32) + << "n must be 16 or 32(for dsv4 group gemm k=2048 or 4096)"; + + int num_group = seqlens.shape().at(0); int tilem = 0; // careful!!! 
here logit must be corresponds with group_gemm_blockwise_fp8_async if (num_seq_per_group_avg <= 16) { @@ -164,21 +171,23 @@ torch::Tensor reformat_x_scale_entry(const torch::Tensor &x_scale, const torch:: tilem = 64; } int num_seq_pad_per_group = m / num_group; - TORCH_CHECK(num_seq_pad_per_group % tilem == 0, - "The sparse pad length of x_scale for each group must be aligned to multiple of " - "16/32/64 according to num_seq_per_group_avg"); + TVM_FFI_ICHECK(num_seq_pad_per_group % tilem == 0) + << "The sparse pad length of x_scale for each group must be aligned to multiple of " + "16/32/64 according to num_seq_per_group_avg"; + + auto device = x_scale.device(); - torch::Tensor output; + tvm::ffi::Tensor output; if (out_x_scale.has_value()) { - output = out_x_scale.value(); + output = tvm::ffi::Tensor(out_x_scale.value()); } else { - output = torch::empty({n, m}, x_scale.options()); + output = tvm_ffi_empty({n, m}, dl_float32, device); } - const auto *xscale_ptr = x_scale.const_data_ptr(); - const auto *seqlens_ptr = seqlens.const_data_ptr(); - const auto *cu_seqlens_ptr = cu_seqlens.const_data_ptr(); - auto *output_ptr = output.mutable_data_ptr(); + const auto *xscale_ptr = x_scale.data_ptr(); + const auto *seqlens_ptr = seqlens.data_ptr(); + const auto *cu_seqlens_ptr = cu_seqlens.data_ptr(); + auto *output_ptr = output.data_ptr(); reformat_x_scale_async(output_ptr, xscale_ptr, seqlens_ptr, cu_seqlens_ptr, num_group, m, n, tilem, stream); @@ -189,23 +198,9 @@ torch::Tensor reformat_x_scale_entry(const torch::Tensor &x_scale, const torch:: } // namespace group_gemm } // namespace hpc -TORCH_LIBRARY_FRAGMENT(hpc, m) { - m.def( - "group_gemm_pertensor_fp8(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, Tensor " - "y_scale, " - "int num_seq_per_group_avg, Tensor? output, Tensor? tma_desc) -> (Tensor)"); - m.impl("group_gemm_pertensor_fp8", torch::kCUDA, - &hpc::group_gemm::group_gemm_pertensor_fp8_entry); - - m.def( - "group_gemm_blockwise_fp8(Tensor x, Tensor weight, Tensor seqlens, Tensor cu_seqlens, Tensor " - "xscale, Tensor wscale," - "int num_seq_per_group_avg, Tensor? output, Tensor? tma_desc) -> (Tensor)"); - m.impl("group_gemm_blockwise_fp8", torch::kCUDA, - &hpc::group_gemm::group_gemm_blockwise_fp8_entry); - - m.def( - "reformat_x_scale(Tensor x_scale, Tensor seqlens, Tensor cu_seqlens, " - "Tensor? out_x_scale, int num_seq_per_group_avg) -> (Tensor)"); - m.impl("reformat_x_scale", torch::kCUDA, &hpc::group_gemm::reformat_x_scale_entry); -} +TVM_FFI_DLL_EXPORT_TYPED_FUNC(group_gemm_pertensor_fp8, + hpc::group_gemm::group_gemm_pertensor_fp8_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(group_gemm_blockwise_fp8, + hpc::group_gemm::group_gemm_blockwise_fp8_entry); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(reformat_x_scale, + hpc::group_gemm::reformat_x_scale_entry); diff --git a/src/utils/include/tvm_ffi_utils.h b/src/utils/include/tvm_ffi_utils.h new file mode 100644 index 0000000..4276ff8 --- /dev/null +++ b/src/utils/include/tvm_ffi_utils.h @@ -0,0 +1,57 @@ +// Copyright (C) 2026 Tencent. 
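This new utility header (continued below) centralizes the DLPack dtype constants and the tvm_ffi_empty / tvm_ffi_zeros allocators shared by every entry point, replacing the torch::empty / TensorOptions pattern used before. The packed int64 codes follow (code << 16) | (bits << 8) | lanes; a quick sketch of the resulting values, assuming the standard DLPack type codes (kDLInt = 0, kDLFloat = 2, kDLBfloat = 4):

def encode_dlpack_dtype(code: int, bits: int, lanes: int = 1) -> int:
    # mirrors encode_dlpack_dtype() in the header below
    return (code << 16) | (bits << 8) | lanes

kDLInt, kDLUInt, kDLFloat, kDLBfloat = 0, 1, 2, 4  # standard DLPack type codes

assert encode_dlpack_dtype(kDLFloat, 32) == 0x22001    # float32_code
assert encode_dlpack_dtype(kDLBfloat, 16) == 0x41001   # bfloat16_code
assert encode_dlpack_dtype(kDLInt, 32) == 0x02001      # int32_code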
+
+#pragma once
+
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/c_env_api.h>
+
+#include <cuda_runtime.h>
+
+inline constexpr int64_t encode_dlpack_dtype(DLDataType dtype) {
+  return (dtype.code << 16) | (dtype.bits << 8) | dtype.lanes;
+}
+
+constexpr DLDataType dl_float16 = DLDataType{kDLFloat, 16, 1};
+constexpr DLDataType dl_float32 = DLDataType{kDLFloat, 32, 1};
+constexpr DLDataType dl_float64 = DLDataType{kDLFloat, 64, 1};
+constexpr DLDataType dl_bfloat16 = DLDataType{kDLBfloat, 16, 1};
+constexpr DLDataType dl_int32 = DLDataType{kDLInt, 32, 1};
+constexpr DLDataType dl_int64 = DLDataType{kDLInt, 64, 1};
+constexpr DLDataType dl_uint8 = DLDataType{kDLUInt, 8, 1};
+constexpr DLDataType dl_int8 = DLDataType{kDLInt, 8, 1};
+constexpr DLDataType dl_float8_e4m3 = DLDataType{6 /*kDLFloat8_e4m3fn*/, 8, 1};
+
+constexpr int64_t float16_code = encode_dlpack_dtype(dl_float16);
+constexpr int64_t float32_code = encode_dlpack_dtype(dl_float32);
+constexpr int64_t float64_code = encode_dlpack_dtype(dl_float64);
+constexpr int64_t bfloat16_code = encode_dlpack_dtype(dl_bfloat16);
+constexpr int64_t int32_code = encode_dlpack_dtype(dl_int32);
+
+#define TVM_FFI_GET_CUDA_STREAM(data) \
+  static_cast<cudaStream_t>(TVMFFIEnvGetStream(data.device().device_type, data.device().device_id))
+
+#define CHECK_CUDA_SUCCESS(err)                                                         \
+  do {                                                                                  \
+    TVM_FFI_ICHECK(err == cudaSuccess) << "CUDA Failure: " << cudaGetErrorString(err);  \
+  } while (0)
+
+#define TVM_FFI_CHECK_CUDA(input) \
+  TVM_FFI_ICHECK(input.device().device_type == kDLCUDA) << #input " tensor must be cuda"
+
+#define TVM_FFI_CHECK_CONTIGUOUS(input) \
+  TVM_FFI_ICHECK(input.IsContiguous()) << #input " tensor must be contiguous"
+
+inline tvm::ffi::Tensor tvm_ffi_empty(std::vector<int64_t> shape, DLDataType dtype,
+                                      DLDevice device) {
+  tvm::ffi::ShapeView sv(shape.data(), shape.size());
+  return tvm::ffi::Tensor::FromEnvAlloc(TVMFFIEnvTensorAlloc, sv, dtype, device);
+}
+
+inline tvm::ffi::Tensor tvm_ffi_zeros(std::vector<int64_t> shape, DLDataType dtype,
+                                      DLDevice device) {
+  tvm::ffi::ShapeView sv(shape.data(), shape.size());
+  auto t = tvm::ffi::Tensor::FromEnvAlloc(TVMFFIEnvTensorAlloc, sv, dtype, device);
+  cudaMemset(t.data_ptr(), 0, t.numel() * (dtype.bits / 8));
+  return t;
+}
diff --git a/tests/test_act.py b/tests/test_act.py
index 07e1d69..76c308e 100644
--- a/tests/test_act.py
+++ b/tests/test_act.py
@@ -2,7 +2,7 @@
 import sys
 from pathlib import Path
 
-sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0]))
+sys.path.insert(0, str(Path(__file__).parent.parent))
 
 from typing import Tuple
 
diff --git a/tests/test_attention_decode_bf16.py b/tests/test_attention_decode_bf16.py
index ecee194..e882be3 100644
--- a/tests/test_attention_decode_bf16.py
+++ b/tests/test_attention_decode_bf16.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0]))
+sys.path.insert(0, str(Path(__file__).parent.parent))
 
 import math
 
diff --git a/tests/test_attention_decode_fp8.py b/tests/test_attention_decode_fp8.py
index e2c57be..a7a8ab9 100644
--- a/tests/test_attention_decode_fp8.py
+++ b/tests/test_attention_decode_fp8.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0]))
+sys.path.insert(0, str(Path(__file__).parent.parent))
 
 import math
 
diff --git a/tests/test_attention_prefill_bf16.py b/tests/test_attention_prefill_bf16.py
index e1a464e..53d1324 100644
--- a/tests/test_attention_prefill_bf16.py
+++ b/tests/test_attention_prefill_bf16.py
@@ -4,7 +4,7 @@
 
 import
pytest -sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) +sys.path.insert(0, str(Path(__file__).parent.parent)) import math diff --git a/tests/test_attention_with_kvcache_prefill_bf16.py b/tests/test_attention_with_kvcache_prefill_bf16.py index 79dc689..a95a213 100644 --- a/tests/test_attention_with_kvcache_prefill_bf16.py +++ b/tests/test_attention_with_kvcache_prefill_bf16.py @@ -4,7 +4,7 @@ import pytest -sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) +sys.path.insert(0, str(Path(__file__).parent.parent)) import math diff --git a/tests/test_attention_with_kvcache_prefill_fp8.py b/tests/test_attention_with_kvcache_prefill_fp8.py index 9454208..9c3f71e 100644 --- a/tests/test_attention_with_kvcache_prefill_fp8.py +++ b/tests/test_attention_with_kvcache_prefill_fp8.py @@ -4,7 +4,7 @@ import pytest -sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) +sys.path.insert(0, str(Path(__file__).parent.parent)) import math diff --git a/tests/test_fuse_moe_blockwise.py b/tests/test_fuse_moe_blockwise.py index 4f3456c..2dbac84 100644 --- a/tests/test_fuse_moe_blockwise.py +++ b/tests/test_fuse_moe_blockwise.py @@ -2,7 +2,7 @@ import sys from pathlib import Path -sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) +sys.path.insert(0, str(Path(__file__).parent.parent)) import math from pathlib import Path diff --git a/tests/test_fuse_moe_pertensor.py b/tests/test_fuse_moe_pertensor.py index 3ed35c8..ef2bde0 100644 --- a/tests/test_fuse_moe_pertensor.py +++ b/tests/test_fuse_moe_pertensor.py @@ -2,7 +2,7 @@ import sys from pathlib import Path -sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) +sys.path.insert(0, str(Path(__file__).parent.parent)) import math from pathlib import Path diff --git a/tests/test_group_gemm_blockwise.py b/tests/test_group_gemm_blockwise.py index 3c3a0ed..057653d 100644 --- a/tests/test_group_gemm_blockwise.py +++ b/tests/test_group_gemm_blockwise.py @@ -2,7 +2,7 @@ import sys from pathlib import Path -sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) +sys.path.insert(0, str(Path(__file__).parent.parent)) import math diff --git a/tests/test_group_gemm_pertensor.py b/tests/test_group_gemm_pertensor.py index b2aa75e..f96e7a2 100644 --- a/tests/test_group_gemm_pertensor.py +++ b/tests/test_group_gemm_pertensor.py @@ -2,7 +2,7 @@ import sys from pathlib import Path -sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) +sys.path.insert(0, str(Path(__file__).parent.parent)) import math diff --git a/tests/test_version.py b/tests/test_version.py index cbf5334..8e6eeaa 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -2,7 +2,7 @@ import os from pathlib import Path -sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) +sys.path.insert(0, str(Path(__file__).parent.parent)) import hpc diff --git a/version_with_meta.py b/version_with_meta.py new file mode 100644 index 0000000..73bcd87 --- /dev/null +++ b/version_with_meta.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import os +import subprocess + +__all__ = ["dynamic_metadata"] + +base_version = "0.0.1" + + +def dynamic_metadata( + field: str, + settings: dict[str, object] | None = None, +) -> str: + try: + assert field == "version" + + patched_version = base_version + + # Try 
to add git hash + try: + git_hash = subprocess.check_output( + ["git", "rev-parse", "--short=7", "HEAD"], + stderr=subprocess.DEVNULL, + text=True, + ).strip() + patched_version += f".dev0+g{git_hash}" + except Exception: + pass + + # VERSION_SUFFIX='+cu128' + if version_ext := os.environ.get("VERSION_SUFFIX"): + patched_version = base_version + version_ext + elif cuda_version := os.environ.get("CUDA_VERSION"): + major, minor, *_ = cuda_version.split(".") + backend = f"+cu{major}{minor}" + patched_version = base_version + backend + + return patched_version + except Exception: + return base_version
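Editor's note: the snippet below is not part of the patch. It is a minimal sketch of how the dynamic_metadata hook in version_with_meta.py resolves the wheel version under different build environments; the environment values ("12.8", "+cu121") are made-up examples, and the printed strings assume the logic above.

# Illustrative only -- exercises the version_with_meta.dynamic_metadata hook
# defined in this patch; run from the repository root.
import os

from version_with_meta import dynamic_metadata

# Default: base version plus a ".dev0+g<short-hash>" local tag when git metadata
# is available, e.g. "0.0.1.dev0+gabcdef0"; plain "0.0.1" outside a git checkout.
print(dynamic_metadata("version"))

# CUDA_VERSION in the build environment replaces the dev tag with a backend tag,
# e.g. CUDA_VERSION=12.8 -> "0.0.1+cu128".
os.environ["CUDA_VERSION"] = "12.8"
print(dynamic_metadata("version"))

# An explicit VERSION_SUFFIX takes precedence over CUDA_VERSION, e.g. "0.0.1+cu121".
os.environ["VERSION_SUFFIX"] = "+cu121"
print(dynamic_metadata("version"))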