24 changes: 23 additions & 1 deletion README.md
@@ -61,6 +61,7 @@ That's it! All `torch.cuda.*` APIs are automatically redirected to `torch.musa.*`
| Distributed | `dist.init_process_group(backend='nccl')` → uses MCCL |
| torch.compile | `torch.compile(model)` with all backends |
| C++ Extensions | `CUDAExtension`, `BuildExtension`, `load()` |
| ctypes Libraries | `ctypes.CDLL` with CUDA function names → MUSA equivalents |

## Examples

@@ -146,6 +147,21 @@ with torch.profiler.profile(
model(x)
```

### ctypes Library Loading

```python
import torchada
import ctypes

# Load MUSA runtime library with CUDA function names
lib = ctypes.CDLL("libmusart.so")
func = lib.cudaMalloc # Automatically translates to musaMalloc

# Works with MCCL too
nccl_lib = ctypes.CDLL("libmccl.so")
func = nccl_lib.ncclAllReduce # Automatically translates to mcclAllReduce
```

## Platform Detection

```python
Expand Down Expand Up @@ -192,9 +208,15 @@ if torchada.is_gpu_device(device): # Works on both CUDA and MUSA
| `is_cuda_platform()` | Returns True if running on CUDA |
| `is_gpu_device(device)` | Returns True if device is CUDA or MUSA |
| `CUDA_HOME` | Path to CUDA/MUSA installation |
| `cuda_to_musa_name(name)` | Convert `cudaXxx` → `musaXxx` |
| `nccl_to_mccl_name(name)` | Convert `ncclXxx` → `mcclXxx` |
| `cublas_to_mublas_name(name)` | Convert `cublasXxx` → `mublasXxx` |
| `curand_to_murand_name(name)` | Convert `curandXxx` → `murandXxx` |

**Note**: `torch.cuda.is_available()` is intentionally NOT redirected — it returns `False` on MUSA. This allows proper platform detection. Use `torch.musa.is_available()` or the `is_musa()` function instead.
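
For example, a minimal backend-selection sketch built on that behavior (a sketch only: it assumes torchada has been imported so its patches are active, and uses the `is_musa()` helper named above):

```python
import torch

import torchada

if torchada.is_musa():
    # torch.cuda.is_available() is False here by design, but "cuda"
    # devices are still redirected to MUSA by torchada.
    device = torch.device("cuda")
elif torch.cuda.is_available():
    device = torch.device("cuda")  # genuine CUDA
else:
    device = torch.device("cpu")
```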

**Note**: The name conversion utilities are exported for manual use, but `ctypes.CDLL` is automatically patched to translate function names when loading MUSA libraries.
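
For instance, a quick sketch of manual use (the return values follow the prefix swaps documented in the table above):

```python
import torchada

torchada.cuda_to_musa_name("cudaMalloc")       # -> "musaMalloc"
torchada.nccl_to_mccl_name("ncclAllReduce")    # -> "mcclAllReduce"
torchada.cublas_to_mublas_name("cublasSgemm")  # -> "mublasSgemm"
torchada.curand_to_murand_name("curandCreateGenerator")  # -> "murandCreateGenerator"
```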

## C++ Extension Symbol Mapping

When building C++ extensions, torchada automatically translates CUDA symbols to MUSA:
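
A minimal sketch of a setup script that relies on this (the extension name and `.cu` source are placeholders; torchada is imported before `torch.utils.cpp_extension` on the assumption that its patches must be in place before the extension classes are used):

```python
import torchada  # applies the CUDA → MUSA patches on import

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_ext",  # placeholder
    ext_modules=[
        # CUDA symbols in the source are translated to MUSA at build time
        CUDAExtension("my_ext", ["my_ext.cu"]),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```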
@@ -216,7 +238,7 @@ See `src/torchada/_mapping.py` for the complete mapping table (380+ mappings).

```
# pyproject.toml or requirements.txt
-torchada>=0.1.24
+torchada>=0.1.25
```

### Step 2: Conditional Import
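
A minimal sketch, assuming the goal is simply to activate torchada only where it is installed (CUDA-only environments keep working without it):

```python
try:
    import torchada  # noqa: F401  (patches apply on import)
except ImportError:
    pass  # plain CUDA environment; no adapter needed
```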
24 changes: 23 additions & 1 deletion README_CN.md
@@ -61,6 +61,7 @@ torch.cuda.synchronize()
| Distributed training | `dist.init_process_group(backend='nccl')` → uses MCCL |
| torch.compile | `torch.compile(model)` with all backends |
| C++ Extensions | `CUDAExtension`, `BuildExtension`, `load()` |
| ctypes library loading | `ctypes.CDLL` with CUDA function names → automatically converted to MUSA |

## Examples

@@ -146,6 +147,21 @@ with torch.profiler.profile(
model(x)
```

### ctypes Library Loading

```python
import torchada
import ctypes

# Load the MUSA runtime library with CUDA function names
lib = ctypes.CDLL("libmusart.so")
func = lib.cudaMalloc # Automatically translates to musaMalloc

# Works with MCCL too
nccl_lib = ctypes.CDLL("libmccl.so")
func = nccl_lib.ncclAllReduce # Automatically translates to mcclAllReduce
```

## Platform Detection

```python
Expand Down Expand Up @@ -192,9 +208,15 @@ if torchada.is_gpu_device(device): # 在 CUDA 和 MUSA 上都能工作
| `is_cuda_platform()` | 在 CUDA 上运行时返回 True |
| `is_gpu_device(device)` | 设备是 CUDA 或 MUSA 时返回 True |
| `CUDA_HOME` | CUDA/MUSA 安装路径 |
| `cuda_to_musa_name(name)` | 转换 `cudaXxx` → `musaXxx` |
| `nccl_to_mccl_name(name)` | 转换 `ncclXxx` → `mcclXxx` |
| `cublas_to_mublas_name(name)` | 转换 `cublasXxx` → `mublasXxx` |
| `curand_to_murand_name(name)` | 转换 `curandXxx` → `murandXxx` |

**Note**: `torch.cuda.is_available()` is intentionally NOT redirected; it returns `False` on MUSA. This enables proper platform detection. Use `torch.musa.is_available()` or the `is_musa()` function instead.

**Note**: The name conversion utilities are exported for manual use, but `ctypes.CDLL` is automatically patched to translate function names when loading MUSA libraries.

## C++ Extension Symbol Mapping

When building C++ extensions, torchada automatically translates CUDA symbols to MUSA:
@@ -216,7 +238,7 @@ if torchada.is_gpu_device(device): # Works on both CUDA and MUSA

```
# pyproject.toml or requirements.txt
-torchada>=0.1.24
+torchada>=0.1.25
```

### Step 2: Conditional Import
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "torchada"
version = "0.1.24"
version = "0.1.25"
description = "Adapter package for torch_musa to act exactly like PyTorch CUDA"
readme = "README.md"
license = {text = "MIT"}
13 changes: 12 additions & 1 deletion src/torchada/__init__.py
@@ -23,7 +23,7 @@
from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CUDA_HOME
"""

__version__ = "0.1.24"
__version__ = "0.1.25"

from . import cuda, utils
from ._patch import apply_patches, get_original_init_process_group, is_patched
@@ -38,6 +38,12 @@
is_gpu_device,
is_musa_platform,
)
from ._runtime import (
cublas_to_mublas_name,
cuda_to_musa_name,
curand_to_murand_name,
nccl_to_mccl_name,
)
from .utils.cpp_extension import CUDA_HOME

# Automatically apply patches on import
@@ -89,4 +95,9 @@ def get_backend():
"get_original_init_process_group",
# C++ Extension building
"CUDA_HOME",
# Runtime name conversion utilities
"cuda_to_musa_name",
"nccl_to_mccl_name",
"cublas_to_mublas_name",
"curand_to_murand_name",
]
151 changes: 151 additions & 0 deletions src/torchada/_patch.py
@@ -1065,6 +1065,153 @@ def _patch_autotune_process():
autotune_process.CUDA_VISIBLE_DEVICES = "MUSA_VISIBLE_DEVICES"


class _CDLLWrapper:
"""
Wrapper for ctypes.CDLL that automatically translates CUDA/NCCL function names
to MUSA/MCCL equivalents when accessing library functions.

This allows code that uses ctypes to load CUDA libraries (libcudart, libnccl) and
access CUDA-named functions to work transparently on MUSA without code changes.

Example:
# Original code uses CUDA function names:
lib = ctypes.CDLL("libmusart.so")
func = lib.cudaIpcOpenMemHandle # Automatically translates to musaIpcOpenMemHandle

lib = ctypes.CDLL("libmccl.so")
func = lib.ncclAllReduce # Automatically translates to mcclAllReduce
"""

# Detect library type from filename patterns
_MUSART_PATTERNS = ("libmusart", "musart.so", "libmusa_runtime")
_MCCL_PATTERNS = ("libmccl", "mccl.so")
_MUBLAS_PATTERNS = ("libmublas", "mublas.so")
_MURAND_PATTERNS = ("libmurand", "murand.so")

def __init__(self, cdll_instance, lib_path: str):
# Store the original CDLL instance
object.__setattr__(self, "_cdll", cdll_instance)
object.__setattr__(self, "_lib_path", lib_path)
object.__setattr__(self, "_lib_type", self._detect_lib_type(lib_path))

def _detect_lib_type(self, lib_path: str) -> str:
"""Detect the type of library from its path."""
lib_path_lower = lib_path.lower()
if any(p in lib_path_lower for p in self._MUSART_PATTERNS):
return "musart"
elif any(p in lib_path_lower for p in self._MCCL_PATTERNS):
return "mccl"
elif any(p in lib_path_lower for p in self._MUBLAS_PATTERNS):
return "mublas"
elif any(p in lib_path_lower for p in self._MURAND_PATTERNS):
return "murand"
return "unknown"

def _translate_name(self, name: str) -> str:
"""Translate CUDA/NCCL function name to MUSA/MCCL equivalent."""
lib_type = object.__getattribute__(self, "_lib_type")

if lib_type == "musart":
# cudaXxx -> musaXxx
if name.startswith("cuda"):
return "musa" + name[4:]
elif lib_type == "mccl":
# ncclXxx -> mcclXxx
if name.startswith("nccl"):
return "mccl" + name[4:]
elif lib_type == "mublas":
# cublasXxx -> mublasXxx
if name.startswith("cublas"):
return "mublas" + name[6:]
elif lib_type == "murand":
# curandXxx -> murandXxx
if name.startswith("curand"):
return "murand" + name[6:]

return name

def __getattr__(self, name: str):
cdll = object.__getattribute__(self, "_cdll")
translated_name = self._translate_name(name)
return getattr(cdll, translated_name)

def __setattr__(self, name: str, value):
cdll = object.__getattribute__(self, "_cdll")
translated_name = self._translate_name(name)
setattr(cdll, translated_name, value)

def __getitem__(self, name: str):
cdll = object.__getattribute__(self, "_cdll")
translated_name = self._translate_name(name)
return cdll[translated_name]


# Store original ctypes.CDLL for patching
_original_ctypes_CDLL = None


@patch_function
def _patch_ctypes_cdll():
"""
Patch ctypes.CDLL to automatically translate CUDA/NCCL function names to MUSA/MCCL.

This allows code that uses ctypes to directly call CUDA runtime or NCCL functions
(like sglang's cuda_wrapper.py and pynccl.py) to work transparently on MUSA
without requiring code changes.

When loading MUSA libraries (libmusart.so, libmccl.so, etc.), the returned CDLL
wrapper will automatically translate function name lookups:
- cudaXxx -> musaXxx (for libmusart)
- ncclXxx -> mcclXxx (for libmccl)
- cublasXxx -> mublasXxx (for libmublas)
- curandXxx -> murandXxx (for libmurand)

Example (in sglang):
lib = ctypes.CDLL("libmusart.so")
# This will automatically find musaIpcOpenMemHandle:
func = lib.cudaIpcOpenMemHandle
"""
import ctypes

global _original_ctypes_CDLL

# Only patch once
if _original_ctypes_CDLL is not None:
return

_original_ctypes_CDLL = ctypes.CDLL

class PatchedCDLL:
"""Patched CDLL that wraps MUSA libraries with function name translation."""

def __new__(cls, name, *args, **kwargs):
# Create the original CDLL instance
cdll_instance = _original_ctypes_CDLL(name, *args, **kwargs)

# Check if this is a MUSA library that needs wrapping
name_str = str(name) if name else ""
            if any(
                pattern in name_str.lower()
                for pattern in (
                    _CDLLWrapper._MUSART_PATTERNS
                    + _CDLLWrapper._MCCL_PATTERNS
                    + _CDLLWrapper._MUBLAS_PATTERNS
                    + _CDLLWrapper._MURAND_PATTERNS
                )
            ):
return _CDLLWrapper(cdll_instance, name_str)

# For non-MUSA libraries, return the original CDLL instance
return cdll_instance

    # Note: code that bound ctypes.CDLL before this patch ran (including
    # ctypes.cdll, which captured the original class when ctypes was first
    # imported) still uses the unpatched class.
    ctypes.CDLL = PatchedCDLL


def apply_patches():
"""
Apply all necessary patches for CUDA to MUSA translation.
@@ -1087,6 +1234,10 @@ def apply_patches():
- torch.amp.autocast(device_type='cuda') -> 'musa'
- torch.utils.cpp_extension (CUDAExtension, BuildExtension) -> MUSA versions
- torch._inductor.autotune_process.CUDA_VISIBLE_DEVICES -> MUSA_VISIBLE_DEVICES
- ctypes.CDLL function name translation for MUSA libraries:
- cudaXxx -> musaXxx (for libmusart)
- ncclXxx -> mcclXxx (for libmccl)
- cublasXxx -> mublasXxx, curandXxx -> murandXxx (for libmublas, libmurand)

This function should be called once at import time.
