From fb3da74054cb9355399278dfc2578c02fc18560e Mon Sep 17 00:00:00 2001
From: Xiaodong Ye
Date: Wed, 28 Jan 2026 13:46:36 +0800
Subject: [PATCH 1/2] Support ctypes.CDLL patching for lib function name conversion

Signed-off-by: Xiaodong Ye
---
 README.md                |  22 ++++
 README_CN.md             |  22 ++++
 src/torchada/__init__.py |  11 ++
 src/torchada/_patch.py   | 151 +++++++++++++++++++++
 src/torchada/_runtime.py | 132 +++++++++++++++++++
 tests/test_mappings.py   | 275 +++++++++++++++++++++++++++++++++++++++
 6 files changed, 613 insertions(+)
 create mode 100644 src/torchada/_runtime.py

diff --git a/README.md b/README.md
index 3bf341e..5a8a723 100644
--- a/README.md
+++ b/README.md
@@ -61,6 +61,7 @@ That's it! All `torch.cuda.*` APIs are automatically redirected to `torch.musa.*
 | Distributed | `dist.init_process_group(backend='nccl')` → uses MCCL |
 | torch.compile | `torch.compile(model)` with all backends |
 | C++ Extensions | `CUDAExtension`, `BuildExtension`, `load()` |
+| ctypes Libraries | `ctypes.CDLL` with CUDA function names → MUSA equivalents |
 
 ## Examples
 
@@ -146,6 +147,21 @@ with torch.profiler.profile(
     model(x)
 ```
 
+### ctypes Library Loading
+
+```python
+import torchada
+import ctypes
+
+# Load MUSA runtime library with CUDA function names
+lib = ctypes.CDLL("libmusart.so")
+func = lib.cudaMalloc  # Automatically translates to musaMalloc
+
+# Works with MCCL too
+nccl_lib = ctypes.CDLL("libmccl.so")
+func = nccl_lib.ncclAllReduce  # Automatically translates to mcclAllReduce
+```
+
 ## Platform Detection
 
 ```python
@@ -192,9 +208,15 @@ if torchada.is_gpu_device(device):  # Works on both CUDA and MUSA
 | `is_cuda_platform()` | Returns True if running on CUDA |
 | `is_gpu_device(device)` | Returns True if device is CUDA or MUSA |
 | `CUDA_HOME` | Path to CUDA/MUSA installation |
+| `cuda_to_musa_name(name)` | Convert `cudaXxx` → `musaXxx` |
+| `nccl_to_mccl_name(name)` | Convert `ncclXxx` → `mcclXxx` |
+| `cublas_to_mublas_name(name)` | Convert `cublasXxx` → `mublasXxx` |
+| `curand_to_murand_name(name)` | Convert `curandXxx` → `murandXxx` |
 
 **Note**: `torch.cuda.is_available()` is intentionally NOT redirected — it returns `False` on MUSA. This allows proper platform detection. Use `torch.musa.is_available()` or `is_musa()` function instead.
 
+**Note**: The name conversion utilities are exported for manual use, but `ctypes.CDLL` is automatically patched to translate function names when loading MUSA libraries.
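+
+For one-off manual conversions, a minimal sketch using the helpers listed above (the automatic `ctypes.CDLL` patch makes this unnecessary in most code):
+
+```python
+import torchada
+
+# Pure string translation - no library handle is needed
+assert torchada.cuda_to_musa_name("cudaMemcpy") == "musaMemcpy"
+assert torchada.nccl_to_mccl_name("ncclAllReduce") == "mcclAllReduce"
+```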
+
 ## C++ Extension Symbol Mapping
 
 When building C++ extensions, torchada automatically translates CUDA symbols to MUSA:
diff --git a/README_CN.md b/README_CN.md
index 5f7f4ad..ca07ebb 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -61,6 +61,7 @@ torch.cuda.synchronize()
 | 分布式训练 | `dist.init_process_group(backend='nccl')` → 使用 MCCL |
 | torch.compile | `torch.compile(model)` 支持所有后端 |
 | C++ 扩展 | `CUDAExtension`, `BuildExtension`, `load()` |
+| ctypes 库加载 | `ctypes.CDLL` 使用 CUDA 函数名 → 自动转换为 MUSA |
 
 ## 示例
 
@@ -146,6 +147,21 @@ with torch.profiler.profile(
     model(x)
 ```
 
+### ctypes 库加载
+
+```python
+import torchada
+import ctypes
+
+# 使用 CUDA 函数名加载 MUSA 运行时库
+lib = ctypes.CDLL("libmusart.so")
+func = lib.cudaMalloc  # 自动转换为 musaMalloc
+
+# 同样适用于 MCCL
+nccl_lib = ctypes.CDLL("libmccl.so")
+func = nccl_lib.ncclAllReduce  # 自动转换为 mcclAllReduce
+```
+
 ## 平台检测
 
 ```python
@@ -192,9 +208,15 @@ if torchada.is_gpu_device(device):  # 在 CUDA 和 MUSA 上都能工作
 | `is_cuda_platform()` | 在 CUDA 上运行时返回 True |
 | `is_gpu_device(device)` | 设备是 CUDA 或 MUSA 时返回 True |
 | `CUDA_HOME` | CUDA/MUSA 安装路径 |
+| `cuda_to_musa_name(name)` | 转换 `cudaXxx` → `musaXxx` |
+| `nccl_to_mccl_name(name)` | 转换 `ncclXxx` → `mcclXxx` |
+| `cublas_to_mublas_name(name)` | 转换 `cublasXxx` → `mublasXxx` |
+| `curand_to_murand_name(name)` | 转换 `curandXxx` → `murandXxx` |
 
 **注意**：`torch.cuda.is_available()` 故意没有重定向 — 在 MUSA 上返回 `False`。这是为了支持正确的平台检测。请改用 `torch.musa.is_available()` 或 `is_musa()` 函数。
 
+**注意**：名称转换工具函数可供手动使用，但 `ctypes.CDLL` 已自动打补丁，加载 MUSA 库时会自动转换函数名。
+
 ## C++ 扩展符号映射
 
 构建 C++ 扩展时，torchada 会自动将 CUDA 符号转换为 MUSA:
diff --git a/src/torchada/__init__.py b/src/torchada/__init__.py
index 745cc67..67bc790 100644
--- a/src/torchada/__init__.py
+++ b/src/torchada/__init__.py
@@ -38,6 +38,12 @@
     is_gpu_device,
     is_musa_platform,
 )
+from ._runtime import (
+    cublas_to_mublas_name,
+    cuda_to_musa_name,
+    curand_to_murand_name,
+    nccl_to_mccl_name,
+)
 from .utils.cpp_extension import CUDA_HOME
 
 # Automatically apply patches on import
@@ -89,4 +95,9 @@ def get_backend():
     "get_original_init_process_group",
     # C++ Extension building
     "CUDA_HOME",
+    # Runtime name conversion utilities
+    "cuda_to_musa_name",
+    "nccl_to_mccl_name",
+    "cublas_to_mublas_name",
+    "curand_to_murand_name",
 ]
diff --git a/src/torchada/_patch.py b/src/torchada/_patch.py
index 4ab9b4e..834b888 100644
--- a/src/torchada/_patch.py
+++ b/src/torchada/_patch.py
@@ -1065,6 +1065,153 @@ def _patch_autotune_process():
     autotune_process.CUDA_VISIBLE_DEVICES = "MUSA_VISIBLE_DEVICES"
 
 
+class _CDLLWrapper:
+    """
+    Wrapper for ctypes.CDLL that automatically translates CUDA/NCCL function names
+    to MUSA/MCCL equivalents when accessing library functions.
+
+    This allows code that uses ctypes to load CUDA libraries (libcudart, libnccl) and
+    access CUDA-named functions to work transparently on MUSA without code changes.
+
+    Example:
+        # Original code uses CUDA function names:
+        lib = ctypes.CDLL("libmusart.so")
+        func = lib.cudaIpcOpenMemHandle  # Automatically translates to musaIpcOpenMemHandle
+
+        lib = ctypes.CDLL("libmccl.so")
+        func = lib.ncclAllReduce  # Automatically translates to mcclAllReduce
+    """
+
+    # Detect library type from filename patterns
+    _MUSART_PATTERNS = ("libmusart", "musart.so", "libmusa_runtime")
+    _MCCL_PATTERNS = ("libmccl", "mccl.so")
+    _MUBLAS_PATTERNS = ("libmublas", "mublas.so")
+    _MURAND_PATTERNS = ("libmurand", "murand.so")
+
+    def __init__(self, cdll_instance, lib_path: str):
+        # Store the original CDLL instance
+        object.__setattr__(self, "_cdll", cdll_instance)
+        object.__setattr__(self, "_lib_path", lib_path)
+        object.__setattr__(self, "_lib_type", self._detect_lib_type(lib_path))
+
+    def _detect_lib_type(self, lib_path: str) -> str:
+        """Detect the type of library from its path."""
+        lib_path_lower = lib_path.lower()
+        if any(p in lib_path_lower for p in self._MUSART_PATTERNS):
+            return "musart"
+        elif any(p in lib_path_lower for p in self._MCCL_PATTERNS):
+            return "mccl"
+        elif any(p in lib_path_lower for p in self._MUBLAS_PATTERNS):
+            return "mublas"
+        elif any(p in lib_path_lower for p in self._MURAND_PATTERNS):
+            return "murand"
+        return "unknown"
+
+    def _translate_name(self, name: str) -> str:
+        """Translate CUDA/NCCL function name to MUSA/MCCL equivalent."""
+        lib_type = object.__getattribute__(self, "_lib_type")
+
+        if lib_type == "musart":
+            # cudaXxx -> musaXxx
+            if name.startswith("cuda"):
+                return "musa" + name[4:]
+        elif lib_type == "mccl":
+            # ncclXxx -> mcclXxx
+            if name.startswith("nccl"):
+                return "mccl" + name[4:]
+        elif lib_type == "mublas":
+            # cublasXxx -> mublasXxx
+            if name.startswith("cublas"):
+                return "mublas" + name[6:]
+        elif lib_type == "murand":
+            # curandXxx -> murandXxx
+            if name.startswith("curand"):
+                return "murand" + name[6:]
+
+        return name
+
+    def __getattr__(self, name: str):
+        cdll = object.__getattribute__(self, "_cdll")
+        translated_name = self._translate_name(name)
+        return getattr(cdll, translated_name)
+
+    def __setattr__(self, name: str, value):
+        cdll = object.__getattribute__(self, "_cdll")
+        translated_name = self._translate_name(name)
+        setattr(cdll, translated_name, value)
+
+    def __getitem__(self, name: str):
+        cdll = object.__getattribute__(self, "_cdll")
+        translated_name = self._translate_name(name)
+        return cdll[translated_name]
+
+
+# Store original ctypes.CDLL for patching
+_original_ctypes_CDLL = None
+
+
+@patch_function
+def _patch_ctypes_cdll():
+    """
+    Patch ctypes.CDLL to automatically translate CUDA/NCCL function names to MUSA/MCCL.
+
+    This allows code that uses ctypes to directly call CUDA runtime or NCCL functions
+    (like sglang's cuda_wrapper.py and pynccl.py) to work transparently on MUSA
+    without requiring code changes.
+
+    When loading MUSA libraries (libmusart.so, libmccl.so, etc.), the returned CDLL
+    wrapper will automatically translate function name lookups:
+    - cudaXxx -> musaXxx (for libmusart)
+    - ncclXxx -> mcclXxx (for libmccl)
+    - cublasXxx -> mublasXxx (for libmublas)
+    - curandXxx -> murandXxx (for libmurand)
+
+    Example (in sglang):
+        lib = ctypes.CDLL("libmusart.so")
+        # This will automatically find musaIpcOpenMemHandle:
+        func = lib.cudaIpcOpenMemHandle
+    """
+    import ctypes
+
+    global _original_ctypes_CDLL
+
+    # Only patch once
+    if _original_ctypes_CDLL is not None:
+        return
+
+    _original_ctypes_CDLL = ctypes.CDLL
+
+    class PatchedCDLL:
+        """Patched CDLL that wraps MUSA libraries with function name translation."""
+
+        def __new__(cls, name, *args, **kwargs):
+            # Create the original CDLL instance
+            cdll_instance = _original_ctypes_CDLL(name, *args, **kwargs)
+
+            # Check if this is a MUSA library that needs wrapping
+            name_str = str(name) if name else ""
+            if any(
+                pattern in name_str.lower()
+                for pattern in (
+                    "libmusart",
+                    "musart.so",
+                    "libmusa_runtime",
+                    "libmccl",
+                    "mccl.so",
+                    "libmublas",
+                    "mublas.so",
+                    "libmurand",
+                    "murand.so",
+                )
+            ):
+                return _CDLLWrapper(cdll_instance, name_str)
+
+            # For non-MUSA libraries, return the original CDLL instance
+            return cdll_instance
+
+    ctypes.CDLL = PatchedCDLL
+
+
 def apply_patches():
     """
     Apply all necessary patches for CUDA to MUSA translation.
@@ -1087,6 +1234,10 @@ def apply_patches():
     - torch.amp.autocast(device_type='cuda') -> 'musa'
     - torch.utils.cpp_extension (CUDAExtension, BuildExtension) -> MUSA versions
     - torch._inductor.autotune_process.CUDA_VISIBLE_DEVICES -> MUSA_VISIBLE_DEVICES
+    - ctypes.CDLL function name translation for MUSA libraries:
+      - cudaXxx -> musaXxx (for libmusart)
+      - ncclXxx -> mcclXxx (for libmccl)
+      - cublasXxx -> mublasXxx, curandXxx -> murandXxx (for libmublas, libmurand)
 
     This function should be called once at import time.
 
diff --git a/src/torchada/_runtime.py b/src/torchada/_runtime.py
new file mode 100644
index 0000000..128be47
--- /dev/null
+++ b/src/torchada/_runtime.py
@@ -0,0 +1,132 @@
+"""
+Runtime name conversion utilities for CUDA to MUSA.
+
+This module provides utility functions for converting CUDA function/library
+names to their MUSA equivalents at runtime.
+
+Note: torchada automatically patches ctypes.CDLL to translate function names
+when loading MUSA libraries (libmusart.so, libmccl.so, etc.). Most users don't
+need to use these functions directly - just import torchada and use ctypes
+normally with CUDA function names.
+
+Example of automatic patching (no code changes needed):
+    import torchada
+    import ctypes
+
+    # Load MUSA runtime library
+    lib = ctypes.CDLL("libmusart.so")
+
+    # Access using CUDA function names - automatically translated!
+    func = lib.cudaIpcOpenMemHandle  # -> musaIpcOpenMemHandle
+
+These utility functions are exported for manual use if needed:
+    from torchada import cuda_to_musa_name, nccl_to_mccl_name
+
+    musa_name = cuda_to_musa_name("cudaIpcOpenMemHandle")  # -> "musaIpcOpenMemHandle"
+    mccl_name = nccl_to_mccl_name("ncclAllReduce")  # -> "mcclAllReduce"
+"""
+
+
+def cuda_to_musa_name(name: str) -> str:
+    """
+    Convert a CUDA function/symbol name to its MUSA equivalent.
+
+    This handles the common naming convention where CUDA functions start with
+    "cuda" and MUSA equivalents start with "musa".
+
+    Args:
+        name: The CUDA function name (e.g., "cudaIpcOpenMemHandle")
+
+    Returns:
+        The MUSA equivalent name (e.g., "musaIpcOpenMemHandle")
+
+    Examples:
+        >>> cuda_to_musa_name("cudaMalloc")
+        'musaMalloc'
+        >>> cuda_to_musa_name("cudaIpcOpenMemHandle")
+        'musaIpcOpenMemHandle'
+        >>> cuda_to_musa_name("cudaError_t")
+        'musaError_t'
+        >>> cuda_to_musa_name("someOtherFunc")
+        'someOtherFunc'
+    """
+    if name.startswith("cuda"):
+        return "musa" + name[4:]
+    return name
+
+
+def nccl_to_mccl_name(name: str) -> str:
+    """
+    Convert an NCCL function/symbol name to its MCCL equivalent.
+
+    This handles the common naming convention where NCCL functions start with
+    "nccl" and MCCL equivalents start with "mccl".
+
+    Args:
+        name: The NCCL function name (e.g., "ncclAllReduce")
+
+    Returns:
+        The MCCL equivalent name (e.g., "mcclAllReduce")
+
+    Examples:
+        >>> nccl_to_mccl_name("ncclAllReduce")
+        'mcclAllReduce'
+        >>> nccl_to_mccl_name("ncclCommInitRank")
+        'mcclCommInitRank'
+        >>> nccl_to_mccl_name("ncclUniqueId")
+        'mcclUniqueId'
+        >>> nccl_to_mccl_name("someOtherFunc")
+        'someOtherFunc'
+    """
+    if name.startswith("nccl"):
+        return "mccl" + name[4:]
+    return name
+
+
+def cublas_to_mublas_name(name: str) -> str:
+    """
+    Convert a cuBLAS function/symbol name to its muBLAS equivalent.
+
+    This handles the common naming convention where cuBLAS functions start with
+    "cublas" and muBLAS equivalents start with "mublas".
+
+    Args:
+        name: The cuBLAS function name (e.g., "cublasCreate")
+
+    Returns:
+        The muBLAS equivalent name (e.g., "mublasCreate")
+
+    Examples:
+        >>> cublas_to_mublas_name("cublasCreate")
+        'mublasCreate'
+        >>> cublas_to_mublas_name("cublasSgemm")
+        'mublasSgemm'
+        >>> cublas_to_mublas_name("someOtherFunc")
+        'someOtherFunc'
+    """
+    if name.startswith("cublas"):
+        return "mublas" + name[6:]
+    return name
+
+
+def curand_to_murand_name(name: str) -> str:
+    """
+    Convert a cuRAND function/symbol name to its muRAND equivalent.
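+
+    This handles the common naming convention where cuRAND functions start with
+    "curand" and muRAND equivalents start with "murand".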
+
+    Args:
+        name: The cuRAND function name (e.g., "curandCreate")
+
+    Returns:
+        The muRAND equivalent name (e.g., "murandCreate")
+
+    Examples:
+        >>> curand_to_murand_name("curandCreate")
+        'murandCreate'
+        >>> curand_to_murand_name("curand_init")
+        'murand_init'
+        >>> curand_to_murand_name("someOtherFunc")
+        'someOtherFunc'
+    """
+    if name.startswith("curand"):
+        return "murand" + name[6:]
+    return name
diff --git a/tests/test_mappings.py b/tests/test_mappings.py
index 3aef17f..58a2036 100644
--- a/tests/test_mappings.py
+++ b/tests/test_mappings.py
@@ -667,3 +667,278 @@ def test_include_specific_paths(self):
 
         # And that generic at::cuda also exists
         assert "at::cuda" in _MAPPING_RULE
+
+
+class TestRuntimeNameConversion:
+    """Test runtime name conversion utilities."""
+
+    def test_cuda_to_musa_name(self):
+        """Test CUDA to MUSA function name conversion."""
+        from torchada import cuda_to_musa_name
+
+        # Basic conversions
+        assert cuda_to_musa_name("cudaMalloc") == "musaMalloc"
+        assert cuda_to_musa_name("cudaFree") == "musaFree"
+        assert cuda_to_musa_name("cudaIpcOpenMemHandle") == "musaIpcOpenMemHandle"
+        assert cuda_to_musa_name("cudaIpcGetMemHandle") == "musaIpcGetMemHandle"
+        assert cuda_to_musa_name("cudaMemset") == "musaMemset"
+        assert cuda_to_musa_name("cudaError_t") == "musaError_t"
+
+        # Non-cuda names should be unchanged
+        assert cuda_to_musa_name("someOtherFunc") == "someOtherFunc"
+        assert cuda_to_musa_name("malloc") == "malloc"
+
+    def test_nccl_to_mccl_name(self):
+        """Test NCCL to MCCL function name conversion."""
+        from torchada import nccl_to_mccl_name
+
+        # Basic conversions
+        assert nccl_to_mccl_name("ncclAllReduce") == "mcclAllReduce"
+        assert nccl_to_mccl_name("ncclCommInitRank") == "mcclCommInitRank"
+        assert nccl_to_mccl_name("ncclBroadcast") == "mcclBroadcast"
+        assert nccl_to_mccl_name("ncclUniqueId") == "mcclUniqueId"
+        assert nccl_to_mccl_name("ncclGetErrorString") == "mcclGetErrorString"
+
+        # Non-nccl names should be unchanged
+        assert nccl_to_mccl_name("someOtherFunc") == "someOtherFunc"
+
+    def test_cublas_to_mublas_name(self):
+        """Test cuBLAS to muBLAS function name conversion."""
+        from torchada import cublas_to_mublas_name
+
+        assert cublas_to_mublas_name("cublasCreate") == "mublasCreate"
+        assert cublas_to_mublas_name("cublasSgemm") == "mublasSgemm"
+        assert cublas_to_mublas_name("cublasDestroy") == "mublasDestroy"
+
+        # Non-cublas names should be unchanged
+        assert cublas_to_mublas_name("someOtherFunc") == "someOtherFunc"
+
+    def test_curand_to_murand_name(self):
+        """Test cuRAND to muRAND function name conversion."""
+        from torchada import curand_to_murand_name
+
+        assert curand_to_murand_name("curandCreate") == "murandCreate"
+        assert curand_to_murand_name("curand_init") == "murand_init"
+
+        # Non-curand names should be unchanged
+        assert curand_to_murand_name("someOtherFunc") == "someOtherFunc"
+
+
+class TestCDLLWrapper:
+    """Test ctypes.CDLL wrapper for automatic function name translation.
+
+    These tests load actual MUSA libraries from /usr/local/musa/lib/ and verify
+    that CUDA function names are automatically translated to MUSA equivalents.
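+
+    Pattern exercised by these tests (sketch):
+
+        lib = ctypes.CDLL("/usr/local/musa/lib/libmusart.so")
+        lib.cudaMalloc  # resolved as musaMalloc by the wrapper
+
+    Each test skips itself when the corresponding library is not installed.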
+ """ + + MUSA_LIB_PATH = "/usr/local/musa/lib" + + def test_cdll_wrapper_class_exists(self): + """Test that _CDLLWrapper class is available.""" + from torchada._patch import _CDLLWrapper + + assert _CDLLWrapper is not None + + @pytest.mark.musa + def test_libmusart_cuda_to_musa_translation(self): + """Test that CUDA function names are translated when loading libmusart.so.""" + import ctypes + import os + + import torchada # noqa: F401 - Apply patches + + lib_path = os.path.join(self.MUSA_LIB_PATH, "libmusart.so") + if not os.path.exists(lib_path): + pytest.skip(f"libmusart.so not found at {lib_path}") + + # Load the library using patched ctypes.CDLL + lib = ctypes.CDLL(lib_path) + + # Access using CUDA function names - should be translated to MUSA + # These should NOT raise AttributeError because they get translated + func = lib.cudaMalloc + assert func is not None + + func = lib.cudaFree + assert func is not None + + func = lib.cudaGetDevice + assert func is not None + + func = lib.cudaIpcOpenMemHandle + assert func is not None + + func = lib.cudaIpcGetMemHandle + assert func is not None + + @pytest.mark.musa + def test_libmccl_nccl_to_mccl_translation(self): + """Test that NCCL function names are translated when loading libmccl.so.""" + import ctypes + import os + + import torchada # noqa: F401 - Apply patches + + lib_path = os.path.join(self.MUSA_LIB_PATH, "libmccl.so") + if not os.path.exists(lib_path): + pytest.skip(f"libmccl.so not found at {lib_path}") + + # Load the library using patched ctypes.CDLL + lib = ctypes.CDLL(lib_path) + + # Access using NCCL function names - should be translated to MCCL + func = lib.ncclAllReduce + assert func is not None + + func = lib.ncclBroadcast + assert func is not None + + func = lib.ncclCommInitRank + assert func is not None + + @pytest.mark.musa + def test_libmublas_cublas_to_mublas_translation(self): + """Test that cuBLAS function names are translated when loading libmublas.so.""" + import ctypes + import os + + import torchada # noqa: F401 - Apply patches + + lib_path = os.path.join(self.MUSA_LIB_PATH, "libmublas.so") + if not os.path.exists(lib_path): + pytest.skip(f"libmublas.so not found at {lib_path}") + + # Load the library using patched ctypes.CDLL + lib = ctypes.CDLL(lib_path) + + # Access using cuBLAS function names - should be translated to muBLAS + # mublasCreate is the actual function name in libmublas.so + func = lib.cublasCreate + assert func is not None + + func = lib.cublasDestroy + assert func is not None + + @pytest.mark.musa + def test_libmurand_curand_to_murand_translation(self): + """Test that cuRAND function names are translated when loading libmurand.so.""" + import ctypes + import os + + import torchada # noqa: F401 - Apply patches + + lib_path = os.path.join(self.MUSA_LIB_PATH, "libmurand.so") + if not os.path.exists(lib_path): + pytest.skip(f"libmurand.so not found at {lib_path}") + + # Load the library using patched ctypes.CDLL + lib = ctypes.CDLL(lib_path) + + # Access using cuRAND function names - should be translated to muRAND + func = lib.curandCreateGenerator + assert func is not None + + @pytest.mark.musa + def test_non_musa_lib_no_translation(self): + """Test that non-MUSA libraries don't get function name translation.""" + import ctypes + + import torchada # noqa: F401 - Apply patches + + # Load a standard library that exists on all systems + try: + lib = ctypes.CDLL("libc.so.6") + except OSError: + pytest.skip("libc.so.6 not found") + + # This should NOT be wrapped, so accessing cudaMalloc should fail + # 
+        with pytest.raises(AttributeError):
+            _ = lib.cudaMalloc
+
+    @pytest.mark.musa
+    def test_sglang_cuda_wrapper_pattern(self):
+        """Test that sglang's cuda_wrapper.py pattern works seamlessly with torchada.
+
+        This test simulates the approach used in sglang's cuda_wrapper.py:
+        https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py
+
+        The key insight is that sglang uses CUDA function names (cudaMalloc, cudaFree, etc.)
+        when accessing functions from the library. With torchada's ctypes.CDLL patch,
+        these names are automatically translated to MUSA equivalents.
+        """
+        import ctypes
+        import os
+        from dataclasses import dataclass
+        from typing import Any, List
+
+        import torchada  # noqa: F401 - Apply patches
+
+        lib_path = os.path.join(self.MUSA_LIB_PATH, "libmusart.so")
+        if not os.path.exists(lib_path):
+            pytest.skip(f"libmusart.so not found at {lib_path}")
+
+        # === Types from sglang's cuda_wrapper.py ===
+        cudaError_t = ctypes.c_int
+
+        @dataclass
+        class Function:
+            name: str
+            restype: Any
+            argtypes: List[Any]
+
+        # === Subset of exported functions (same as sglang) ===
+        exported_functions = [
+            # cudaError_t cudaSetDevice ( int device )
+            Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
+            # cudaError_t cudaDeviceSynchronize ( void )
+            Function("cudaDeviceSynchronize", cudaError_t, []),
+            # cudaError_t cudaMalloc ( void** devPtr, size_t size )
+            Function(
+                "cudaMalloc",
+                cudaError_t,
+                [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
+            ),
+            # cudaError_t cudaFree ( void* devPtr )
+            Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
+        ]
+
+        # === Load library using sglang's pattern ===
+        # sglang does: lib = ctypes.CDLL(so_file)
+        lib = ctypes.CDLL(lib_path)
+
+        # === Access functions using CUDA names (sglang's pattern) ===
+        # sglang does: f = getattr(self.lib, func.name)
+        # With torchada, this automatically translates cudaXxx -> musaXxx
+        funcs = {}
+        for func in exported_functions:
+            # This is the key line - sglang uses CUDA names here
+            f = getattr(lib, func.name)
+            f.restype = func.restype
+            f.argtypes = func.argtypes
+            funcs[func.name] = f
+
+        # Verify all functions were loaded successfully
+        assert "cudaSetDevice" in funcs
+        assert "cudaDeviceSynchronize" in funcs
+        assert "cudaMalloc" in funcs
+        assert "cudaFree" in funcs
+
+        # === Actually call the functions to verify they work ===
+        # Set device 0
+        result = funcs["cudaSetDevice"](0)
+        assert result == 0, f"cudaSetDevice failed with error {result}"
+
+        # Allocate memory
+        devPtr = ctypes.c_void_p()
+        result = funcs["cudaMalloc"](ctypes.byref(devPtr), 1024)
+        assert result == 0, f"cudaMalloc failed with error {result}"
+        assert devPtr.value is not None
+
+        # Free memory
+        result = funcs["cudaFree"](devPtr)
+        assert result == 0, f"cudaFree failed with error {result}"
+
+        # Synchronize
+        result = funcs["cudaDeviceSynchronize"]()
+        assert result == 0, f"cudaDeviceSynchronize failed with error {result}"

From 27559db3c49c07f56d2f92fd9d0cd93f64b2d1c9 Mon Sep 17 00:00:00 2001
From: Xiaodong Ye
Date: Wed, 28 Jan 2026 15:12:29 +0800
Subject: [PATCH 2/2] Bump version

Signed-off-by: Xiaodong Ye
---
 README.md                | 2 +-
 README_CN.md             | 2 +-
 pyproject.toml           | 2 +-
 src/torchada/__init__.py | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 5a8a723..3118879 100644
--- a/README.md
+++ b/README.md
@@ -238,7 +238,7 @@ See `src/torchada/_mapping.py` for the complete mapping table (380+ mappings).
 
 ```
 # pyproject.toml or requirements.txt
-torchada>=0.1.24
+torchada>=0.1.25
 ```
 
 ### Step 2: Conditional Import
diff --git a/README_CN.md b/README_CN.md
index ca07ebb..c82b7de 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -238,7 +238,7 @@ if torchada.is_gpu_device(device):  # 在 CUDA 和 MUSA 上都能工作
 
 ```
 # pyproject.toml 或 requirements.txt
-torchada>=0.1.24
+torchada>=0.1.25
 ```
 
 ### 步骤 2：条件导入
diff --git a/pyproject.toml b/pyproject.toml
index 0cbf4b8..b1f6596 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "torchada"
-version = "0.1.24"
+version = "0.1.25"
 description = "Adapter package for torch_musa to act exactly like PyTorch CUDA"
 readme = "README.md"
 license = {text = "MIT"}
diff --git a/src/torchada/__init__.py b/src/torchada/__init__.py
index 67bc790..cfc1d45 100644
--- a/src/torchada/__init__.py
+++ b/src/torchada/__init__.py
@@ -23,7 +23,7 @@
     from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CUDA_HOME
 """
 
-__version__ = "0.1.24"
+__version__ = "0.1.25"
 
 from . import cuda, utils
 from ._patch import apply_patches, get_original_init_process_group, is_patched