diff --git a/AGENTS.md b/AGENTS.md index ce89a99..d3a7f7a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -21,6 +21,10 @@ src/torchada/ ├── _patch.py # All patching logic (~1100 lines) ├── _platform.py # Platform detection utilities ├── _mapping.py # CUDA→MUSA symbol mappings for C++ extensions +├── _cpp_ops.py # C++ operator overrides infrastructure +├── csrc/ # C++ source files for operator overrides +│ ├── ops.h # Header with utilities and examples +│ └── ops.cpp # Main C++ source with Python bindings ├── cuda/ # CUDA module compatibility └── utils/cpp_extension.py # CUDAExtension wrapper tests/ @@ -68,7 +72,7 @@ pytest tests/ --tb=short **Test Markers** (defined in `conftest.py`): - `@pytest.mark.musa` - Requires MUSA platform -- `@pytest.mark.cuda` - Requires CUDA platform +- `@pytest.mark.cuda` - Requires CUDA platform - `@pytest.mark.gpu` - Requires any GPU - `@pytest.mark.slow` - Slow tests @@ -138,6 +142,32 @@ import uuid lib_name = f"test_lib_{uuid.uuid4().hex[:8]}" ``` +## C++ Operator Overrides + +torchada supports overriding ATen operators at the C++ level for better performance. + +**See [docs/custom_musa_ops.md](docs/custom_musa_ops.md) for detailed documentation.** + +**Quick start**: +```bash +export TORCHADA_ENABLE_CPP_OPS=1 +``` + +**Adding a new operator override**: + +1. Edit `src/torchada/csrc/musa_ops.mu` for MUSA kernels (or `ops.cpp` for pure C++) + +2. Register using `TORCH_LIBRARY_IMPL(aten, PrivateUse1, m)` + +3. The extension is JIT-compiled on first use + +**Environment variables**: +- `TORCHADA_ENABLE_CPP_OPS=1` - Enable C++ operator overrides +- `TORCHADA_CPP_OPS_VERBOSE=1` - Show compilation output +- `TORCHADA_DEBUG_CPP_OPS=1` - Log operator calls +- `TORCHADA_DISABLE_OP_OVERRIDE_=1` - Disable specific operator override +- `MTGPU_TARGET=mp_XX` - Override GPU architecture (auto-detected via `musaInfo`) + ## Security Considerations - All patches are applied at import time via `apply_patches()` diff --git a/docs/custom_musa_ops.md b/docs/custom_musa_ops.md new file mode 100644 index 0000000..376eaa7 --- /dev/null +++ b/docs/custom_musa_ops.md @@ -0,0 +1,236 @@ +# Writing Custom MUSA Operators in torchada + +This guide explains how to write custom MUSA C++ operators that override torch_musa's default ATen implementations. + +## Overview + +torchada allows you to override ATen operators at the C++ level for the `PrivateUse1` (MUSA) dispatch key. This is useful when you need: + +- Better performance than the default torch_musa implementation +- Custom behavior for specific operators +- Workarounds for torch_musa bugs + +## Quick Start + +### 1. Enable C++ Ops + +```bash +export TORCHADA_ENABLE_CPP_OPS=1 +``` + +### 2. 
Write Your Kernel
+
+Edit `src/torchada/csrc/musa_ops.mu`:
+
+```cpp
+#include "ops.h"
+#include <musa_runtime.h>  // musaStream_t and the musa* runtime API
+
+namespace torchada {
+
+template <typename scalar_t>
+__global__ void my_kernel(scalar_t* output, const scalar_t* input, int64_t n) {
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < n) {
+    output[idx] = /* your computation */;
+  }
+}
+
+at::Tensor my_op_impl(const at::Tensor& self) {
+  log_op_call("my_op");
+
+  auto input = self.contiguous();
+  auto output = at::empty_like(input);
+  if (input.numel() == 0) return output;
+
+  musaStream_t stream = at::musa::getCurrentMUSAStream();
+  const int64_t n = input.numel();
+  const int threads = 256;
+  const int blocks = (n + threads - 1) / threads;
+
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "my_op", [&] {
+    my_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+        output.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        n);
+  });
+
+  return output;
+}
+
+} // namespace torchada
+
+TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
+  // Check env var at registration time - allows disabling via
+  // TORCHADA_DISABLE_OP_OVERRIDE_my_op=1
+  if (torchada::is_override_enabled("my_op")) {
+    m.impl("my_op", torchada::my_op_impl);
+  }
+}
+```
+
+### 3. Test Your Kernel
+
+```bash
+TORCHADA_ENABLE_CPP_OPS=1 TORCHADA_DEBUG_CPP_OPS=1 python -c "
+import torch
+import torchada
+
+x = torch.randn(1000, device='cuda')
+y = torch.neg(x)  # Should print '[torchada] neg called'
+print('Result:', y.cpu()[:5])
+"
+```
+
+## File Structure
+
+| File | Purpose |
+|------|---------|
+| `src/torchada/csrc/ops.h` | Header with utilities (`log_op_call`, `is_override_enabled`) |
+| `src/torchada/csrc/ops.cpp` | Python bindings and C++-only operator overrides |
+| `src/torchada/csrc/musa_ops.mu` | MUSA kernel implementations |
+
+## Environment Variables
+
+| Variable | Description |
+|----------|-------------|
+| `TORCHADA_ENABLE_CPP_OPS=1` | Enable C++ operator overrides |
+| `TORCHADA_CPP_OPS_VERBOSE=1` | Show compilation output |
+| `TORCHADA_DEBUG_CPP_OPS=1` | Log operator calls to stdout |
+| `TORCHADA_DISABLE_OP_OVERRIDE_<op_name>=1` | Disable a specific operator override |
+| `MTGPU_TARGET=mp_XX` | Override GPU architecture detection |
+
+### Disabling Specific Operators
+
+To disable a specific operator override, set the environment variable before importing torchada:
+
+```bash
+# Disable the 'neg' operator override, use torch_musa's default instead
+TORCHADA_ENABLE_CPP_OPS=1 TORCHADA_DISABLE_OP_OVERRIDE_neg=1 python my_script.py
+```
+
+**Important**: The operator name in the environment variable must match the name passed to `is_override_enabled()` in the C++ code. For example, if the code uses `is_override_enabled("neg")`, set `TORCHADA_DISABLE_OP_OVERRIDE_neg=1`.
+
+This check happens at **registration time** (when the extension is loaded), not on each operator call. Once the extension is loaded, the operator registrations are fixed.
+
+## GPU Architecture
+
+torchada auto-detects the GPU architecture using `musaInfo`:
+
+| GPU | Compute Capability | Architecture |
+|-----|-------------------|--------------|
+| MTT S80 | 2.1 | mp_21 |
+| MTT S4000 | 2.2 | mp_22 |
+| MTT S5000 | 3.1 | mp_31 |
+
+Override with: `export MTGPU_TARGET=mp_22`
+
+## Best Practices
+
+### Avoid Infinite Recursion
+
+When overriding an operator, don't call the same operator from inside your implementation:
+
+```cpp
+// BAD - causes infinite recursion
+at::Tensor bad_neg_impl(const at::Tensor& self) {
+  return -self;  // Calls aten::neg again!
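+  // (at::neg(self) or self.neg() would recurse just the same - on a MUSA tensor
+  //  they dispatch right back to this PrivateUse1 registration)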
+} + +// GOOD - use lower-level primitives +at::Tensor good_neg_impl(const at::Tensor& self) { + auto output = at::empty_like(self); + // Launch custom kernel or use in-place ops + return output; +} +``` + +### Handle Edge Cases + +```cpp +at::Tensor my_impl(const at::Tensor& self) { + auto input = self.contiguous(); // Ensure contiguous + if (input.numel() == 0) { + return at::empty_like(input); // Handle empty tensors + } + // ... kernel launch +} +``` + +### Check for Errors + +```cpp +musaError_t err = musaGetLastError(); +if (err != musaSuccess) { + TORCH_CHECK(false, "MUSA kernel failed: ", musaGetErrorString(err)); +} +``` + +### Use Type Dispatching + +```cpp +AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + input.scalar_type(), "my_kernel", [&] { + my_kernel<<>>(...); + }); +``` + +## Overridable ATen Operators + +Any ATen operator with a `PrivateUse1` dispatch can be overridden. Common categories: + +### Unary Operations +`abs`, `neg`, `exp`, `exp2`, `log`, `log2`, `log10`, `sqrt`, `rsqrt`, `ceil`, `floor`, `round`, `trunc`, `sign`, `sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `sinh`, `cosh`, `tanh`, `sigmoid`, `erf`, `erfc`, `reciprocal`, `bitwise_not` + +### Binary Operations +`add`, `sub`, `mul`, `div`, `pow`, `fmod`, `remainder`, `maximum`, `minimum`, `atan2`, `bitwise_and`, `bitwise_or`, `bitwise_xor`, `logical_and`, `logical_or`, `logical_xor` + +### Reduction Operations +`sum`, `prod`, `mean`, `std`, `var`, `max`, `min`, `argmax`, `argmin`, `all`, `any`, `norm`, `logsumexp` + +### Matrix Operations +`mm`, `bmm`, `addmm`, `addmv`, `addr`, `matmul`, `dot`, `mv`, `ger`, `linear` + +### Activation Functions +`relu`, `relu_`, `leaky_relu`, `gelu`, `silu`, `mish`, `hardswish`, `hardsigmoid`, `softplus`, `softshrink`, `threshold` + +### Normalization +`batch_norm`, `layer_norm`, `group_norm`, `instance_norm`, `local_response_norm` + +### Pooling +`max_pool1d`, `max_pool2d`, `max_pool3d`, `avg_pool1d`, `avg_pool2d`, `avg_pool3d`, `adaptive_max_pool2d`, `adaptive_avg_pool2d` + +### Convolution +`conv1d`, `conv2d`, `conv3d`, `conv_transpose1d`, `conv_transpose2d`, `conv_transpose3d` + +### Memory Operations +`copy_`, `clone`, `contiguous`, `fill_`, `zero_`, `ones_like`, `zeros_like`, `empty_like` + +### Indexing +`index`, `index_put_`, `gather`, `scatter`, `scatter_add`, `masked_fill`, `masked_select`, `where` + +### Shape Operations +`view`, `reshape`, `transpose`, `permute`, `squeeze`, `unsqueeze`, `expand`, `repeat`, `cat`, `stack`, `split`, `chunk` + +To find the exact operator signature, use: + +```python +import torch +# Search for specific operator: +for s in torch._C._jit_get_all_schemas(): + if 'neg' in str(s): + print(s) +``` + +## Complete Example + +See `src/torchada/csrc/musa_ops.mu` for a complete working example that overrides `aten::neg`. + +## Debugging + +1. **Verify kernel is called**: Set `TORCHADA_DEBUG_CPP_OPS=1` +2. **Check compilation**: Set `TORCHADA_CPP_OPS_VERBOSE=1` +3. **Clear cache**: `rm -rf ~/.cache/torch_extensions/*/torchada_cpp_ops` +4. 
**Check architecture**: Run `musaInfo | grep "compute capability"` + diff --git a/pyproject.toml b/pyproject.toml index cbafdb9..ba24659 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,9 @@ Repository = "https://github.com/MooreThreads/torchada" [tool.setuptools.packages.find] where = ["src"] +[tool.setuptools.package-data] +torchada = ["csrc/*.h", "csrc/*.cpp", "csrc/*.cu", "csrc/*.mu", "csrc/*.cuh", "csrc/*.muh"] + [tool.black] line-length = 100 target-version = ["py38", "py39", "py310", "py311", "py312"] diff --git a/src/torchada/__init__.py b/src/torchada/__init__.py index 33d2e61..52107de 100644 --- a/src/torchada/__init__.py +++ b/src/torchada/__init__.py @@ -43,6 +43,11 @@ # Automatically apply patches on import apply_patches() +# Load C++ operator overrides if enabled via TORCHADA_ENABLE_CPP_OPS=1 +from ._cpp_ops import load_cpp_ops + +load_cpp_ops() + def get_version() -> str: """Return the version of torchada.""" diff --git a/src/torchada/_cpp_ops.py b/src/torchada/_cpp_ops.py new file mode 100644 index 0000000..8275d4c --- /dev/null +++ b/src/torchada/_cpp_ops.py @@ -0,0 +1,176 @@ +# Copyright (c) torchada contributors +# SPDX-License-Identifier: MIT +""" +C++ operator overrides infrastructure for torchada. + +This module handles building and loading C++ extensions that can override +ATen operator implementations for the PrivateUse1 (MUSA) dispatch key. + +Usage: + # Enable C++ ops by setting environment variable + export TORCHADA_ENABLE_CPP_OPS=1 + + # Then import torchada as usual + import torchada + + # Or explicitly load + from torchada._cpp_ops import load_cpp_ops + load_cpp_ops() +""" + +import os +import subprocess +from typing import Optional + +_cpp_ops_module: Optional[object] = None + + +def _detect_musa_arch() -> str: + """ + Detect MUSA GPU architecture from compute capability. + + Uses musaInfo to get compute capability and maps it to architecture: + - 2.1 -> mp_21 (MTT S80) + - 2.2 -> mp_22 (MTT S4000) + - 3.1 -> mp_31 (MTT S5000) + + Returns: + Architecture string like "mp_21", "mp_22", or "mp_31". + Defaults to "mp_31" if detection fails. + """ + try: + result = subprocess.run( + ["musaInfo"], + capture_output=True, + text=True, + timeout=5, + ) + for line in result.stdout.splitlines(): + if "compute capability:" in line.lower(): + # Parse "compute capability: 2.1" + parts = line.split(":") + if len(parts) >= 2: + version = parts[1].strip() + # Convert "2.1" -> "mp_21", "3.1" -> "mp_31" + version_parts = version.split(".") + if len(version_parts) >= 2: + major = version_parts[0].strip() + minor = version_parts[1].strip() + return f"mp_{major}{minor}" + except (subprocess.TimeoutExpired, FileNotFoundError, Exception): + pass + + # Default fallback + return "mp_31" + + +def load_cpp_ops(force_reload: bool = False) -> Optional[object]: + """ + Load the C++ operator overrides extension. + + The extension is only loaded if: + 1. Running on MUSA platform + 2. TORCHADA_ENABLE_CPP_OPS=1 environment variable is set + + Args: + force_reload: If True, reload the extension even if already loaded. + + Returns: + The loaded extension module, or None if not loaded. 
+ """ + global _cpp_ops_module + + if _cpp_ops_module is not None and not force_reload: + return _cpp_ops_module + + # Check if enabled via environment variable + if os.environ.get("TORCHADA_ENABLE_CPP_OPS") != "1": + return None + + # Check if on MUSA platform + from ._platform import is_musa_platform + + if not is_musa_platform(): + return None + + try: + import os.path as osp + + csrc_dir = osp.join(osp.dirname(__file__), "csrc") + + # Collect all source files + cpp_sources = [] + musa_sources = [] + + for fname in os.listdir(csrc_dir): + fpath = osp.join(csrc_dir, fname) + if fname.endswith(".cpp"): + cpp_sources.append(fpath) + elif fname.endswith((".cu", ".mu")): + musa_sources.append(fpath) + + if not cpp_sources and not musa_sources: + import warnings + + warnings.warn("torchada C++ ops: no source files found") + return None + + verbose = os.environ.get("TORCHADA_CPP_OPS_VERBOSE") == "1" + all_sources = cpp_sources + musa_sources + + # Use MUSA extension loader if we have MUSA sources, otherwise use torch's + if musa_sources: + # Use torchada's load which handles MUSA properly + from .utils.cpp_extension import load + + # Get MUSA architecture flags + # Use MTGPU_TARGET env var if set, otherwise auto-detect from GPU + extra_cuda_cflags = [] + mtgpu_target = os.environ.get("MTGPU_TARGET", "") + if not mtgpu_target: + mtgpu_target = _detect_musa_arch() + extra_cuda_cflags.append(f"--offload-arch={mtgpu_target}") + + _cpp_ops_module = load( + name="torchada_cpp_ops", + sources=all_sources, + extra_include_paths=[csrc_dir], + extra_cuda_cflags=extra_cuda_cflags, + verbose=verbose, + ) + else: + # Pure C++ extension - use torch's loader directly + from torch.utils.cpp_extension import load + + _cpp_ops_module = load( + name="torchada_cpp_ops", + sources=all_sources, + extra_include_paths=[csrc_dir], + verbose=verbose, + ) + + _cpp_ops_module._mark_loaded() + return _cpp_ops_module + + except Exception as e: + import warnings + + warnings.warn(f"Failed to load torchada C++ ops: {e}") + return None + + +def is_loaded() -> bool: + """Check if C++ ops extension is loaded.""" + return _cpp_ops_module is not None + + +def get_version() -> Optional[str]: + """Get C++ ops extension version.""" + if _cpp_ops_module is None: + return None + return _cpp_ops_module.get_version() + + +def get_module() -> Optional[object]: + """Get the loaded C++ ops module.""" + return _cpp_ops_module diff --git a/src/torchada/csrc/musa_ops.mu b/src/torchada/csrc/musa_ops.mu new file mode 100644 index 0000000..a03f325 --- /dev/null +++ b/src/torchada/csrc/musa_ops.mu @@ -0,0 +1,87 @@ +// torchada MUSA operator overrides +// +// This file contains MUSA kernel implementations that can override torch_musa's +// default ATen operator implementations. +// +// NOTE: No operators are overridden by default. The implementations below serve +// as examples. To activate an override, uncomment the corresponding m.impl() +// line in the TORCH_LIBRARY_IMPL block at the bottom of this file. 
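+//
+// To confirm that an activated override is actually being hit, run with
+// TORCHADA_DEBUG_CPP_OPS=1; the log_op_call() calls below then print a line
+// such as "[torchada] neg called" for each dispatched call.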
+
+#include "ops.h"
+#include <musa_runtime.h>  // musaStream_t, musaError_t, musaGetLastError
+
+namespace torchada {
+
+// ============================================================================
+// Example: MUSA kernel for neg (negation)
+// This demonstrates how to override aten::neg for PrivateUse1 (MUSA) tensors
+// ============================================================================
+
+template <typename scalar_t>
+__global__ void neg_kernel(
+    scalar_t* __restrict__ output,
+    const scalar_t* __restrict__ input,
+    int64_t numel) {
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < numel) {
+    output[idx] = -input[idx];
+  }
+}
+
+at::Tensor neg_musa_impl(const at::Tensor& self) {
+  log_op_call("neg");
+
+  // Ensure contiguous tensor
+  auto self_contig = self.contiguous();
+
+  // Allocate output tensor
+  auto output = at::empty_like(self_contig);
+
+  if (self_contig.numel() == 0) {
+    return output;
+  }
+
+  // Get MUSA stream
+  musaStream_t stream = at::musa::getCurrentMUSAStream();
+
+  // Launch kernel
+  const int64_t numel = self_contig.numel();
+  const int threads = 256;
+  const int blocks = (numel + threads - 1) / threads;
+
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16,
+      self_contig.scalar_type(), "neg_musa", [&] {
+        neg_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+            output.data_ptr<scalar_t>(),
+            self_contig.data_ptr<scalar_t>(),
+            numel);
+      });
+
+  // Check for launch errors
+  musaError_t err = musaGetLastError();
+  if (err != musaSuccess) {
+    TORCH_CHECK(false, "MUSA kernel launch failed: ", musaGetErrorString(err));
+  }
+
+  return output;
+}
+
+} // namespace torchada
+
+// ============================================================================
+// Register operator overrides for PrivateUse1 (MUSA)
+//
+// Each operator checks TORCHADA_DISABLE_OP_OVERRIDE_<op_name>=1 at registration
+// time. If set, the override is not registered and torch_musa's default
+// implementation is used.
+//
+// Uncomment m.impl() lines to activate custom implementations.
+// ============================================================================
+
+TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
+  // Example: Register neg override only if not disabled
+  // if (torchada::is_override_enabled("neg")) {
+  //   m.impl("neg", torchada::neg_musa_impl);
+  // }
+}
diff --git a/src/torchada/csrc/ops.cpp b/src/torchada/csrc/ops.cpp
new file mode 100644
index 0000000..6c7d955
--- /dev/null
+++ b/src/torchada/csrc/ops.cpp
@@ -0,0 +1,78 @@
+// torchada C++ operator overrides - Main source file
+//
+// This file contains the operator registration infrastructure and example overrides.
+// Custom operator implementations can be added here or in separate files.
+//
+// To add a new operator override:
+//   1. Write the implementation function
+//   2. Register it using TORCH_LIBRARY_IMPL(aten, PrivateUse1, m)
+//
+// Note: Operators registered here will override torch_musa's implementations.
+// Use with caution and ensure correctness.
+ +#include "ops.h" + +namespace torchada { + +// ============================================================================ +// Example: Operator override template (commented out - for reference) +// ============================================================================ +// +// To override an ATen operator, follow this pattern: +// +// static at::Tensor custom_add_impl( +// const at::Tensor& self, +// const at::Tensor& other, +// const at::Scalar& alpha) { +// +// log_op_call("add.Tensor"); +// +// // Your custom implementation here +// // IMPORTANT: Avoid calling the same operator to prevent infinite recursion +// // Use in-place operations or lower-level primitives instead +// auto result = at::empty_like(self); +// result.copy_(self); +// result.add_(other, alpha); +// return result; +// } +// +// Then register it: +// TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { +// m.impl("add.Tensor", custom_add_impl); +// } + +// ============================================================================ +// Utility functions exposed to Python +// ============================================================================ + +static bool cpp_ops_loaded = false; + +bool is_loaded() { + return cpp_ops_loaded; +} + +const char* get_version() { + return VERSION; +} + +void mark_loaded() { + cpp_ops_loaded = true; +} + +} // namespace torchada + +// ============================================================================ +// Python bindings +// ============================================================================ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "torchada C++ operator overrides"; + + m.def("is_loaded", &torchada::is_loaded, + "Check if C++ ops extension is loaded"); + m.def("get_version", &torchada::get_version, + "Get the C++ ops extension version"); + m.def("_mark_loaded", &torchada::mark_loaded, + "Mark the extension as loaded (internal use)"); +} + diff --git a/src/torchada/csrc/ops.h b/src/torchada/csrc/ops.h new file mode 100644 index 0000000..b473e48 --- /dev/null +++ b/src/torchada/csrc/ops.h @@ -0,0 +1,67 @@ +// torchada C++ operator overrides +// +// This header provides the infrastructure for registering custom ATen operator +// implementations that override the default PrivateUse1 (MUSA) implementations. +// +// Usage: +// 1. Include this header in your .cpp or .mu file +// 2. Use TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) to register overrides +// 3. Use is_override_enabled("op_name") to check if override should be registered +// 4. 
The extension will be built and loaded automatically by torchada +// +// Example: +// #include "ops.h" +// +// at::Tensor my_custom_add(const at::Tensor& self, const at::Tensor& other, +// const at::Scalar& alpha) { +// log_op_call("add.Tensor"); // Optional: log when called +// // Custom implementation +// auto result = at::empty_like(self); +// result.copy_(self); +// result.add_(other, alpha); +// return result; +// } +// +// TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { +// // Check env var at registration time - if disabled, don't register +// if (torchada::is_override_enabled("add")) { +// m.impl("add.Tensor", my_custom_add); +// } +// } +// +// Environment variables: +// TORCHADA_DISABLE_OP_OVERRIDE_=1 - Disable specific operator override +// TORCHADA_DEBUG_CPP_OPS=1 - Log operator calls to stdout + +#pragma once + +#include +#include + +namespace torchada { + +// Version information +constexpr const char* VERSION = "0.1.0"; + +// Check if operator override is enabled via environment variable +inline bool is_override_enabled(const char* op_name) { + // Check TORCHADA_DISABLE_OP_OVERRIDE_ environment variable + std::string env_var = "TORCHADA_DISABLE_OP_OVERRIDE_"; + env_var += op_name; + const char* val = std::getenv(env_var.c_str()); + if (val != nullptr && std::string(val) == "1") { + return false; + } + return true; +} + +// Logging helper for debugging +inline void log_op_call(const char* op_name) { + const char* debug = std::getenv("TORCHADA_DEBUG_CPP_OPS"); + if (debug != nullptr && std::string(debug) == "1") { + std::cout << "[torchada] " << op_name << " called" << std::endl; + } +} + +} // namespace torchada + diff --git a/src/torchada/utils/cpp_extension.py b/src/torchada/utils/cpp_extension.py index 421f12f..5d15147 100644 --- a/src/torchada/utils/cpp_extension.py +++ b/src/torchada/utils/cpp_extension.py @@ -638,17 +638,20 @@ def load( import torch_musa.utils.musa_extension as musa_ext # Use MUSA's load function if available + # Note: MUSA uses different parameter names: + # extra_cuda_cflags -> extra_musa_cflags + # with_cuda -> with_musa if hasattr(musa_ext, "load"): return musa_ext.load( name=name, sources=sources, extra_cflags=extra_cflags, - extra_cuda_cflags=extra_cuda_cflags, + extra_musa_cflags=extra_cuda_cflags, extra_ldflags=extra_ldflags, extra_include_paths=extra_include_paths, build_directory=build_directory, verbose=verbose, - with_cuda=with_cuda, + with_musa=with_cuda, is_python_module=is_python_module, is_standalone=is_standalone, keep_intermediates=keep_intermediates, diff --git a/tests/test_cuda_patching.py b/tests/test_cuda_patching.py index 43312b0..1cb0e79 100644 --- a/tests/test_cuda_patching.py +++ b/tests/test_cuda_patching.py @@ -1544,3 +1544,60 @@ def test_cudart_in_dir(self): pytest.skip("Only applicable on MUSA platform") assert "cudart" in dir(torch.cuda) + + +class TestCppOpsInfrastructure: + """Test C++ operator overrides infrastructure.""" + + def test_cpp_ops_module_exists(self): + """Test that the _cpp_ops module exists and can be imported.""" + from torchada import _cpp_ops + + assert hasattr(_cpp_ops, "load_cpp_ops") + assert hasattr(_cpp_ops, "is_loaded") + assert hasattr(_cpp_ops, "get_version") + assert hasattr(_cpp_ops, "get_module") + + def test_cpp_ops_not_loaded_by_default(self): + """Test that C++ ops are not loaded by default.""" + from torchada import _cpp_ops + + # Without TORCHADA_ENABLE_CPP_OPS=1, should not be loaded + # Note: This test may be affected by other tests that load the module + # So we just check the functions 
exist and are callable + assert callable(_cpp_ops.is_loaded) + assert callable(_cpp_ops.get_version) + + def test_cpp_ops_source_files_exist(self): + """Test that the C++ source files are packaged correctly.""" + import os.path as osp + + import torchada + + csrc_dir = osp.join(osp.dirname(torchada.__file__), "csrc") + assert osp.isdir(csrc_dir), f"csrc directory not found: {csrc_dir}" + + ops_h = osp.join(csrc_dir, "ops.h") + ops_cpp = osp.join(csrc_dir, "ops.cpp") + + assert osp.isfile(ops_h), f"ops.h not found: {ops_h}" + assert osp.isfile(ops_cpp), f"ops.cpp not found: {ops_cpp}" + + def test_cpp_ops_header_content(self): + """Test that the C++ header has expected content.""" + import os.path as osp + + import torchada + + csrc_dir = osp.join(osp.dirname(torchada.__file__), "csrc") + ops_h = osp.join(csrc_dir, "ops.h") + + with open(ops_h, "r") as f: + content = f.read() + + # Check for expected content + assert "namespace torchada" in content + assert "TORCH_LIBRARY_IMPL" in content + assert "PrivateUse1" in content + assert "is_override_enabled" in content + assert "log_op_call" in content
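For reference, a minimal usage sketch of how the pieces in this patch fit together, assuming a MUSA machine with the build toolchain available; it only uses the environment variables and `torchada._cpp_ops` helpers introduced above:

```python
import os

# The flag must be set before torchada is imported, because torchada/__init__.py
# calls load_cpp_ops() at import time.
os.environ["TORCHADA_ENABLE_CPP_OPS"] = "1"
os.environ["TORCHADA_DEBUG_CPP_OPS"] = "1"  # log "[torchada] <op> called" for overridden ops

import torch
import torchada
from torchada import _cpp_ops

if _cpp_ops.is_loaded():
    print("C++ ops extension loaded, version", _cpp_ops.get_version())
else:
    # Not on a MUSA platform, flag unset, or the JIT build failed (a warning is emitted).
    print("C++ ops extension not loaded")

x = torch.randn(1000, device="cuda")  # torchada maps the cuda device to MUSA
y = torch.neg(x)  # logs "[torchada] neg called" only if the neg override has been activated
```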